简介:本文详细介绍如何使用UNet模型进行图像语义分割,从自制数据集的标注与预处理,到模型训练与推理测试,提供完整的代码实现与操作指南,适合初学者与进阶开发者。
UNet(U-shaped Network)由Ronneberger等人在2015年提出,是一种专为医学图像分割设计的全卷积网络(FCN)。其核心思想是通过编码器-解码器结构和跳跃连接实现特征的高效提取与空间信息的保留,尤其适用于小数据集与高分辨率图像。相较于其他模型,UNet具有以下优势:结构对称、参数量适中,在小数据集上也能稳定收敛;跳跃连接将编码器的浅层特征与解码器的深层特征融合,能较好地保留边缘等空间细节;对高分辨率输入友好,分割边界更精细。
本文将围绕UNet展开,从数据集制作到模型训练与推理,提供一套完整的解决方案。
步骤1:图像采集
步骤2:图像标准化
import cv2
import numpy as np


def preprocess_image(image_path, target_size=(512, 512)):
    """Load an image, resize it to ``target_size`` and scale pixels to [0, 1].

    Args:
        image_path: Path to an image file readable by OpenCV.
        target_size: (width, height) tuple passed to ``cv2.resize``.

    Returns:
        float32 ndarray of shape (H, W, 3) with values in [0, 1]
        (BGR channel order, as returned by ``cv2.imread``).

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image = cv2.imread(image_path)
    # cv2.imread returns None instead of raising on a bad path; fail
    # early with a clear message rather than crashing inside resize.
    if image is None:
        raise FileNotFoundError(f"Cannot read image: {image_path}")
    image = cv2.resize(image, target_size)
    return image.astype(np.float32) / 255.0  # normalize to [0, 1]
工具选择:
标注流程:
import json

from PIL import Image, ImageDraw


def json_to_mask(json_path, output_path, image_shape=(512, 512)):
    """Convert a labelme-style JSON annotation into a binary PNG mask.

    Args:
        json_path: Path to the annotation JSON (expects a ``shapes`` list,
            each entry with ``shape_type`` and ``points``).
        output_path: Where to save the resulting mask image.
        image_shape: (width, height) of the output mask.
    """
    with open(json_path) as f:
        data = json.load(f)
    mask = Image.new('L', image_shape, 0)  # 'L' = 8-bit grayscale, background 0
    draw = ImageDraw.Draw(mask)
    for shape in data['shapes']:
        if shape['shape_type'] == 'polygon':
            # labelme stores points as [[x, y], ...]; convert to (x, y)
            # tuples, the form ImageDraw.polygon reliably accepts.
            points = [tuple(p) for p in shape['points']]
            draw.polygon(points, fill=255)  # foreground filled with 255
    mask.save(output_path)
**数据集结构**:
dataset/
├── images/
│ ├── img1.jpg
│ └── img2.jpg
└── masks/
├── img1_mask.png
└── img2_mask.png
# --- UNet model implementation (PyTorch) ---
# UNet is a symmetric U-shaped network:
#   * encoder (downsampling): convolution + max-pooling extract features
#   * decoder (upsampling): transposed convolutions restore resolution
#   * skip connections: encoder feature maps are concatenated with the
#     matching decoder outputs to recover spatial detail
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Double convolution block: (Conv2d -> ReLU) x 2, same spatial size."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class UNet(nn.Module):
    """Minimal UNet for semantic segmentation.

    Args:
        n_classes: Number of output channels (1 for binary segmentation,
            paired with ``BCEWithLogitsLoss``).

    Input must be (B, 3, H, W) with H and W divisible by 8 (three
    pooling stages). Output is raw logits of shape (B, n_classes, H, W).
    """

    def __init__(self, n_classes=1):
        super().__init__()
        # Encoder
        self.dconv1 = DoubleConv(3, 64)
        self.dconv2 = DoubleConv(64, 128)
        self.dconv3 = DoubleConv(128, 256)
        self.dconv4 = DoubleConv(256, 512)
        # Decoder. Each skip connection doubles the channel count after
        # concatenation, so a DoubleConv follows every cat to bring it
        # back down. (The original fed the concatenated 512/256/128-channel
        # tensors straight into layers expecting 256/128/64 channels,
        # which raises a shape mismatch at runtime.)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.uconv3 = DoubleConv(512, 256)   # 256 (up) + 256 (skip) -> 256
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.uconv2 = DoubleConv(256, 128)   # 128 + 128 -> 128
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.uconv1 = DoubleConv(128, 64)    # 64 + 64 -> 64
        # 1x1 conv maps features to per-pixel class logits
        self.final_conv = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        # Encoder path
        conv1 = self.dconv1(x)
        pool1 = F.max_pool2d(conv1, 2)
        conv2 = self.dconv2(pool1)
        pool2 = F.max_pool2d(conv2, 2)
        conv3 = self.dconv3(pool2)
        pool3 = F.max_pool2d(conv3, 2)
        conv4 = self.dconv4(pool3)
        # Decoder path with skip connections (concatenate on channel dim)
        up3 = self.upconv3(conv4)
        up3 = self.uconv3(torch.cat([up3, conv3], dim=1))
        up2 = self.upconv2(up3)
        up2 = self.uconv2(torch.cat([up2, conv2], dim=1))
        up1 = self.upconv1(up2)
        up1 = self.uconv1(torch.cat([up1, conv1], dim=1))
        # Raw logits; apply sigmoid/softmax outside the model
        return self.final_conv(up1)
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class CustomDataset(Dataset):
    """(image, mask) pair dataset for segmentation.

    Args:
        image_paths: List of image file paths.
        mask_paths: List of mask file paths, aligned index-for-index
            with ``image_paths``.
        transform: Optional transform applied to the *image* only.
        mask_transform: Optional transform applied to the *mask* only.
            The original code ran the image transform — including the
            3-channel ImageNet Normalize — on the 1-channel mask, which
            crashes on the channel mismatch and is wrong for labels.
    """

    def __init__(self, image_paths, mask_paths, transform=None,
                 mask_transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.image_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)
        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        return image, mask


# Example: build a DataLoader.
# image_paths / mask_paths are lists of file paths built beforehand
# (e.g. from the dataset/ directory layout shown above).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# ToTensor alone: scales the 0/255 mask to 0/1 without normalizing it.
mask_transform = transforms.ToTensor()
dataset = CustomDataset(image_paths, mask_paths,
                        transform=transform, mask_transform=mask_transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
import torch.optim as optim
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(n_classes=1).to(device)
criterion = nn.BCEWithLogitsLoss()  # binary segmentation: logits + BCE
optimizer = optim.Adam(model.parameters(), lr=1e-4)


def train_model(model, dataloader, epochs=50):
    """Train ``model`` on ``dataloader`` for ``epochs`` epochs.

    Uses the module-level ``criterion``, ``optimizer`` and ``device``.
    Prints the mean batch loss after each epoch.
    """
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, masks in tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}'):
            images = images.to(device)
            masks = masks.float().to(device)
            # Ensure masks have a channel dim matching the (B, 1, H, W)
            # model output. The original unsqueezed unconditionally,
            # which produces (B, 1, 1, H, W) when the dataset already
            # yields channel-first masks (e.g. via ToTensor).
            if masks.dim() == 3:
                masks = masks.unsqueeze(1)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader):.4f}')


train_model(model, dataloader)
# Save only the weights (state_dict), not the full pickled module —
# smaller and robust to code refactors.
torch.save(model.state_dict(), 'unet_model.pth')

# Load the weights back into a freshly constructed model.
model = UNet(n_classes=1).to(device)
# map_location lets a checkpoint trained on GPU load on a CPU-only
# machine (the original would fail there with a CUDA deserialization error).
model.load_state_dict(torch.load('unet_model.pth', map_location=device))
model.eval()  # switch to inference mode (affects dropout/batchnorm)
import matplotlib.pyplot as plt


def predict_and_visualize(model, image_path, mask_path=None):
    """Run single-image inference and plot the predicted mask.

    Args:
        model: Trained UNet in eval mode, already on ``device``.
        image_path: Path to the input image.
        mask_path: Optional path to the ground-truth mask; if given,
            it is shown in a second figure for comparison.
    """
    image = preprocess_image(image_path)  # float32 HxWx3 in [0, 1]
    image_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
    # Sigmoid maps raw logits to probabilities in [0, 1]
    pred_mask = torch.sigmoid(output).squeeze().cpu().numpy()
    # Visualize input vs. prediction side by side
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title('Original Image')
    plt.subplot(1, 2, 2)
    plt.imshow(pred_mask, cmap='gray')
    plt.title('Predicted Mask')
    if mask_path:
        true_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        plt.figure(figsize=(5, 5))
        plt.imshow(true_mask, cmap='gray')
        plt.title('True Mask')
    plt.show()


# Example call. The original omitted the required ``model`` argument,
# which raises TypeError at runtime.
predict_and_visualize(model, 'test_image.jpg', 'test_mask.png')
通过本文的完整流程,读者可快速掌握UNet从数据集制作到模型推理的全过程,为实际项目提供技术支撑。