简介:本文提供基于PyTorch的图像分类完整实现,包含数据加载、模型构建、训练与评估全流程代码及详细注释,适合开发者快速掌握深度学习图像分类技术。
图像分类是计算机视觉领域的核心任务。PyTorch作为主流深度学习框架,凭借其动态计算图和Pythonic接口,成为开发者实现图像分类的首选工具之一。本文将通过一个完整的实战案例,展示如何使用PyTorch实现从数据加载、模型构建、训练评估到模型保存的全流程,并提供详细注释帮助读者理解每个步骤的实现原理。
首先需要安装PyTorch及相关依赖库:
pip install torch torchvision matplotlib numpy
建议使用GPU环境加速训练,可通过torch.cuda.is_available()检查CUDA是否可用。
本文使用CIFAR-10数据集,该数据集包含10个类别共60000张32×32的彩色图像,其中50000张用于训练、10000张用于测试。torchvision的torchvision.datasets模块提供了便捷的数据加载接口:
import torch  # needed for torch.utils.data.DataLoader below
import torchvision
import torchvision.transforms as transforms

# Preprocessing pipeline applied to every image.
transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image / ndarray -> float tensor scaled to [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
])

# Training set, reshuffled every epoch. num_workers=2 loads batches in two
# worker processes; on Windows/macOS iterate these loaders only under an
# `if __name__ == "__main__":` guard.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=True, num_workers=2)

# Test set in a fixed order (shuffle=False).
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32,
                                         shuffle=False, num_workers=2)

# Human-readable names for labels 0-9, in CIFAR-10 label order.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
关键点解析:
- transforms.Compose:将多个数据预处理操作组合成一个流水线
- Normalize:按 (x - mean) / std 逐通道标准化;这里均值和标准差均取0.5,会把ToTensor输出的[0,1]范围映射到[-1,1]
- DataLoader:实现批量加载和数据打乱,并可通过num_workers启用多个工作进程(而非线程)并行加载数据

接下来构建一个包含3个卷积层和2个全连接层的CNN模型:
import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
    """Small CIFAR-10 classifier: three conv blocks followed by two fully connected layers.

    Each conv block is ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2, 2)``.
    The padded 3x3 convolutions keep the spatial size and each pool halves it,
    so a 32x32 input shrinks 32 -> 16 -> 8 -> 4 and the flattened feature
    vector has 128 * 4 * 4 = 2048 entries.

    Args:
        num_classes: number of output logits (10 for CIFAR-10).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Conv layers: 3 -> 32 -> 64 -> 128 channels, 3x3 kernels; padding=1 keeps H and W.
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        # Shared 2x2 max pool: halves height and width.
        self.pool = nn.MaxPool2d(2, 2)
        # Classifier head; 128 * 4 * 4 is the flattened size after three pools of a 32x32 input.
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map a batch of 3-channel images to raw logits of shape (batch, num_classes).

        No softmax is applied here; nn.CrossEntropyLoss expects raw logits.
        """
        # Each block: convolution -> ReLU -> 2x2 max pool.
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten everything except the batch dimension. Unlike view(-1, 2048),
        # this never folds feature maps into the batch, so a wrongly sized
        # input raises a shape error in fc1 instead of silently changing the
        # batch size.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
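模型结构解析:
- 3个卷积层均使用3×3卷积核和padding=1,卷积后特征图尺寸不变,通道数依次为3→32→64→128
- 每个卷积层后依次接ReLU激活和2×2最大池化,池化使特征图边长减半:32→16→8→4
- 因此展平后的特征维度为128×4×4=2048,经fc1映射到512维,再由fc2输出10个类别的得分(logits)
- forward末尾不加softmax,因为nn.CrossEntropyLoss内部已包含log-softmax

下面初始化模型并将其移动到GPU(如果可用),然后定义损失函数、优化器以及训练和评估函数: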
import torch

# Choose the compute device: the first CUDA GPU when present, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the network and move its parameters to the chosen device
# (nn.Module.to() moves in place and returns the same module).
model = CNN().to(device)
import torch.optim as optim

# Loss: cross-entropy computed on the model's raw logits.
criterion = nn.CrossEntropyLoss()

# Optimizer: stochastic gradient descent, learning rate 0.001, momentum 0.9.
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.001,
    momentum=0.9,
)
def _infer_device(model):
    """Return the device holding ``model``'s first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    if first_param is None:
        return torch.device("cpu")
    return first_param.device


def train_model(model, trainloader, criterion, optimizer, epochs=10, *,
                eval_loader=None, device=None, log_interval=200):
    """Train ``model`` and report test accuracy after every epoch.

    Args:
        model: network to train; put in training mode at the start of each epoch.
        trainloader: iterable of batches where ``batch[0]`` is the input and
            ``batch[1]`` the labels.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``trainloader``.
        eval_loader: loader for the per-epoch evaluation; defaults to the
            module-level ``testloader`` (the original behavior).
        device: where batches are moved; defaults to the device of the
            model's parameters.
        log_interval: print statistics every this many batches.
    """
    if device is None:
        device = _infer_device(model)
    # Only fall back to the global test loader when none is passed in.
    loader_for_eval = eval_loader if eval_loader is not None else testloader

    for epoch in range(epochs):
        # evaluate_model() switches the model to eval mode, so re-enable
        # training mode each epoch (matters for Dropout/BatchNorm layers).
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for i, data in enumerate(trainloader):
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Loss is averaged over the last `log_interval` batches; accuracy
            # is cumulative since the start of the epoch.
            if (i + 1) % log_interval == 0:
                print(f'Epoch {epoch+1}, Batch {i+1}, '
                      f'Loss: {running_loss/log_interval:.3f}, '
                      f'Acc: {100*correct/total:.2f}%')
                running_loss = 0.0

        test_acc = evaluate_model(model, loader_for_eval, device=device)
        print(f'Epoch {epoch+1} completed, Test Acc: {test_acc:.2f}%')


def evaluate_model(model, testloader, *, device=None):
    """Return the top-1 accuracy of ``model`` on ``testloader`` as a percentage.

    The model runs in eval mode with gradients disabled; its previous
    train/eval mode is restored afterwards.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    if device is None:
        device = _infer_device(model)
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for data in testloader:
                images, labels = data[0].to(device), data[1].to(device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("testloader yielded no samples; cannot compute accuracy")
    return 100 * correct / total
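训练流程解析:
- 每个batch依次执行:梯度清零(optimizer.zero_grad)→前向传播→计算损失→反向传播(loss.backward)→参数更新(optimizer.step)
- 每200个batch打印一次这200个batch的平均损失,以及本epoch截至当前的累计训练准确率
- 每个epoch结束后在测试集上评估,评估时使用torch.no_grad()关闭梯度计算,以节省显存和计算量
- 评估前应调用model.eval()、训练前应调用model.train();本文模型不含Dropout/BatchNorm层,因此不影响结果,但加入这类层后必须正确切换模式

完整代码: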
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt  # currently unused in this script
import numpy as np  # currently unused in this script

# Device configuration: first CUDA GPU if available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 1. Data loading
# ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Human-readable names for labels 0-9, in CIFAR-10 label order.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


def build_loaders(batch_size=32, num_workers=2, root='./data'):
    """Download (if needed) CIFAR-10 and return ``(trainloader, testloader)``."""
    trainset = torchvision.datasets.CIFAR10(root=root, train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=num_workers)
    testset = torchvision.datasets.CIFAR10(root=root, train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=num_workers)
    return trainloader, testloader


# 2. Model definition
class CNN(nn.Module):
    """Three conv blocks (conv -> ReLU -> 2x2 max pool) followed by two FC layers.

    A 32x32 input shrinks 32 -> 16 -> 8 -> 4, giving 128 * 4 * 4 flattened features.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Keep the batch dimension; a wrongly sized input fails loudly in fc1.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


# 4. Training / evaluation functions
def _infer_device(model):
    """Return the device holding ``model``'s first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    if first_param is None:
        return torch.device("cpu")
    return first_param.device


def train_model(model, trainloader, criterion, optimizer, epochs=10, *,
                eval_loader=None, device=None, log_interval=200):
    """Train ``model``; after each epoch, report accuracy on ``eval_loader`` if given.

    Loss is printed as the mean over the last ``log_interval`` batches;
    accuracy is cumulative since the start of the epoch.
    """
    if device is None:
        device = _infer_device(model)
    for epoch in range(epochs):
        # evaluate_model() switches to eval mode, so re-enable training mode every epoch.
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for i, data in enumerate(trainloader):
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            if (i + 1) % log_interval == 0:
                print(f'Epoch {epoch+1}, Batch {i+1}, '
                      f'Loss: {running_loss/log_interval:.3f}, '
                      f'Acc: {100*correct/total:.2f}%')
                running_loss = 0.0

        if eval_loader is not None:
            test_acc = evaluate_model(model, eval_loader, device=device)
            print(f'Epoch {epoch+1} completed, Test Acc: {test_acc:.2f}%')


def evaluate_model(model, testloader, *, device=None):
    """Return top-1 accuracy (percent) in eval mode without gradients; restores the previous mode.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    if device is None:
        device = _infer_device(model)
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for data in testloader:
                images, labels = data[0].to(device), data[1].to(device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("testloader yielded no samples; cannot compute accuracy")
    return 100 * correct / total


def main():
    """Load data, train the CNN for 10 epochs, and save its weights."""
    trainloader, testloader = build_loaders()
    model = CNN().to(device)

    # 3. Training configuration
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # 5. Run training
    train_model(model, trainloader, criterion, optimizer, epochs=10,
                eval_loader=testloader, device=device)

    # 6. Save the model weights
    torch.save(model.state_dict(), 'cifar_cnn.pth')


if __name__ == "__main__":
    # Required: with num_workers > 0, DataLoader workers are started via
    # "spawn" on Windows/macOS, which re-imports this module in each worker.
    # Without the guard, training would restart inside every worker.
    main()
典型运行结果:
Epoch 1, Batch 200, Loss: 2.034, Acc: 32.45%
Epoch 1, Batch 400, Loss: 1.876, Acc: 36.78%
...
Epoch 10 completed, Test Acc: 72.34%
优化建议:
- 使用学习率调度器动态调整学习率(如torch.optim.lr_scheduler.StepLR)

本文完整实现了基于PyTorch的图像分类系统,涵盖了从数据加载、模型构建、训练评估到模型保存的全流程。关键技术点包括:
- 使用torchvision高效加载和预处理数据

扩展方向:
通过本文的实践,读者可以掌握PyTorch实现图像分类的核心技术,为后续更复杂的计算机视觉任务打下坚实基础。