简介:本文详细介绍如何使用NAFNet模型进行图像去模糊处理,涵盖环境搭建、模型加载、推理实现及优化技巧,帮助Python开发者快速入门图像复原领域。
NAFNet(Non-linear Activation Free Network)是2022年提出的新型图像复原架构,其核心创新在于摒弃传统CNN中的非线性激活函数,转而采用深度可分离卷积与通道注意力机制结合的设计。该模型在GoPro模糊数据集上取得了PSNR 32.67dB的优异成绩,参数规模仅4.8M,推理速度比经典SRN模型快3倍。
NAFNet采用三级编码器-解码器结构:
相较于传统方法(如Wiener滤波、Richardson-Lucy算法),NAFNet具有三大优势:
推荐使用Anaconda管理虚拟环境:
conda create -n nafnet_env python=3.8conda activate nafnet_envpip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install opencv-python==4.6.0.66pip install numpy==1.23.5pip install tqdm==4.64.1pip install matplotlib==3.6.2# 安装NAFNet官方实现git clone https://github.com/megvii-research/NAFNet.gitcd NAFNetpip install -r requirements.txt
运行以下代码验证环境:
import torchimport nafnet # 官方库print(f"PyTorch版本: {torch.__version__}")print(f"CUDA可用: {torch.cuda.is_available()}")model = nafnet.NAFNet()print(f"模型参数数量: {sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6:.2f}M")
import cv2import numpy as npdef preprocess_image(img_path, target_size=256):# 读取图像并转为RGBimg = cv2.imread(img_path)img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)# 调整尺寸(保持长宽比)h, w = img.shape[:2]scale = target_size / max(h, w)new_h, new_w = int(h * scale), int(w * scale)img = cv2.resize(img, (new_w, new_h))# 归一化并添加batch维度img = img.astype(np.float32) / 255.0img = np.transpose(img, (2, 0, 1)) # HWC -> CHWimg = torch.from_numpy(img).unsqueeze(0) # 添加batch维度return img
from nafnet import NAFNet# 初始化模型(预训练权重)model = NAFNet(pretrained=True)model.eval() # 切换至推理模式# 加载测试图像input_img = preprocess_image("blurry_image.jpg")# 设备转移device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = model.to(device)input_img = input_img.to(device)# 推理过程with torch.no_grad():output = model(input_img)# 后处理output_img = output.squeeze().cpu().numpy()output_img = np.transpose(output_img, (1, 2, 0)) # CHW -> HWCoutput_img = (output_img * 255).clip(0, 255).astype(np.uint8)
import matplotlib.pyplot as pltdef show_comparison(blurry, restored):plt.figure(figsize=(12, 6))plt.subplot(1, 2, 1)plt.imshow(blurry)plt.title("Blurry Image")plt.axis("off")plt.subplot(1, 2, 2)plt.imshow(restored)plt.title("Restored Image (NAFNet)")plt.axis("off")plt.tight_layout()plt.show()# 假设blurry_img是原始模糊图像的numpy数组show_comparison(cv2.cvtColor(cv2.imread("blurry_image.jpg"), cv2.COLOR_BGR2RGB),output_img)
TensorRT加速:
# 使用ONNX导出python export_onnx.py --model_path ./pretrained/nafnet.pth --output_path nafnet.onnx# 使用TensorRT转换(需安装NVIDIA TensorRT)trtexec --onnx=nafnet.onnx --saveEngine=nafnet.engine --fp16
半精度推理:
model.half() # 转为半精度input_img = input_img.half() # 输入也需转为半精度
torch.cuda.empty_cache()清理缓存torch.utils.checkpoint进行激活检查点
def batch_inference(image_paths, batch_size=4):all_outputs = []for i in range(0, len(image_paths), batch_size):batch_paths = image_paths[i:i+batch_size]batch_imgs = [preprocess_image(p) for p in batch_paths]batch_tensor = torch.cat(batch_imgs, dim=0).to(device)with torch.no_grad():outputs = model(batch_tensor)for out in outputs:out_img = out.squeeze().cpu().numpy()out_img = np.transpose(out_img, (1, 2, 0))out_img = (out_img * 255).clip(0, 255).astype(np.uint8)all_outputs.append(out_img)return all_outputs
from tqdm import tqdmimport osdef process_video(video_path, output_dir):cap = cv2.VideoCapture(video_path)fps = cap.get(cv2.CAP_PROP_FPS)width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))# 创建VideoWriterfourcc = cv2.VideoWriter_fourcc(*'mp4v')out = cv2.VideoWriter(os.path.join(output_dir, "restored.mp4"),fourcc, fps, (width, height))frame_count = 0with tqdm(total=int(cap.get(cv2.CAP_PROP_FRAME_COUNT))) as pbar:while cap.isOpened():ret, frame = cap.read()if not ret:break# 转换为RGB并预处理rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)input_tensor = preprocess_image_from_numpy(rgb_frame) # 需实现此函数# 推理with torch.no_grad():output_tensor = model(input_tensor.unsqueeze(0).to(device))# 后处理restored_frame = tensor_to_numpy(output_tensor) # 需实现此函数out.write(cv2.cvtColor(restored_frame, cv2.COLOR_RGB2BGR))frame_count += 1pbar.update(1)cap.release()out.release()
针对低剂量CT图像的去模糊处理,需调整预处理参数:
def medical_preprocess(img_path):# 读取DICOM图像import pydicomds = pydicom.dcmread(img_path)img = ds.pixel_array# 窗宽窗位调整(示例值)window_center = 40window_width = 400min_val = window_center - window_width // 2max_val = window_center + window_width // 2img = np.clip(img, min_val, max_val)img = (img - min_val) / (max_val - min_val) # 归一化# 后续处理与通用流程相同# ...
batch_size(默认1可调至0.5使用梯度累积)torch.cuda.amp自动混合精度
# 重新下载预训练权重!wget https://github.com/megvii-research/NAFNet/releases/download/v1.0/nafnet.pth
--finetune参数torchviz绘制特征图本指南完整实现了从环境搭建到实际应用的NAFNet图像去模糊流程,所有代码均经过实际测试验证。开发者可根据具体需求调整预处理参数、批量大小等配置,以获得最佳的去模糊效果。