基于PyTorch的图像分类实战:完整代码与深度解析

作者:十万个为什么2025.10.12 01:01浏览量:2

简介:本文提供基于PyTorch的图像分类完整实现,包含数据加载、模型构建、训练与评估全流程代码及详细注释,适合开发者快速掌握深度学习图像分类技术。

基于PyTorch的图像分类实战:完整代码与深度解析

引言

图像分类是计算机视觉领域的核心任务,PyTorch作为主流深度学习框架,凭借其动态计算图和Pythonic接口,成为开发者实现图像分类的首选工具。本文将通过一个完整的实战案例,展示如何使用PyTorch实现从数据加载到模型部署的全流程,并提供详细注释帮助读者理解每个步骤的实现原理。

一、环境准备与数据集选择

1.1 环境配置

首先需要安装PyTorch及相关依赖库:

  1. pip install torch torchvision matplotlib numpy

建议使用GPU环境加速训练,可通过torch.cuda.is_available()检查CUDA是否可用。

1.2 数据集选择

本文使用CIFAR-10数据集,该数据集包含10个类别的60000张32x32彩色图像,分为50000张训练集和10000张测试集。PyTorch的torchvision.datasets模块提供了便捷的数据加载接口:

  1. import torchvision
  2. import torchvision.transforms as transforms
  3. # 定义数据预处理流程
  4. transform = transforms.Compose([
  5. transforms.ToTensor(), # 将PIL图像或numpy数组转为Tensor,并缩放到[0,1]
  6. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化到[-1,1]
  7. ])
  8. # 加载训练集和测试集
  9. trainset = torchvision.datasets.CIFAR10(
  10. root='./data', train=True, download=True, transform=transform)
  11. trainloader = torch.utils.data.DataLoader(
  12. trainset, batch_size=32, shuffle=True, num_workers=2)
  13. testset = torchvision.datasets.CIFAR10(
  14. root='./data', train=False, download=True, transform=transform)
  15. testloader = torch.utils.data.DataLoader(
  16. testset, batch_size=32, shuffle=False, num_workers=2)
  17. # 定义类别标签
  18. classes = ('plane', 'car', 'bird', 'cat', 'deer',
  19. 'dog', 'frog', 'horse', 'ship', 'truck')

关键点解析

  • transforms.Compose:将多个数据预处理操作组合成一个流水线
  • Normalize:使用均值和标准差进行标准化,这里采用(0.5,0.5,0.5)的简单缩放
  • DataLoader:实现批量加载、数据打乱和多线程加速

二、CNN模型构建

2.1 基础CNN模型实现

我们构建一个包含3个卷积层和2个全连接层的CNN模型:

  1. import torch.nn as nn
  2. import torch.nn.functional as F
  3. class CNN(nn.Module):
  4. def __init__(self):
  5. super(CNN, self).__init__()
  6. # 卷积层1:输入通道3,输出通道32,3x3卷积核
  7. self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
  8. # 卷积层2:输入通道32,输出通道64,3x3卷积核
  9. self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
  10. # 卷积层3:输入通道64,输出通道128,3x3卷积核
  11. self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
  12. # 最大池化层
  13. self.pool = nn.MaxPool2d(2, 2)
  14. # 全连接层
  15. self.fc1 = nn.Linear(128 * 4 * 4, 512) # 经过3次池化后,特征图尺寸为4x4
  16. self.fc2 = nn.Linear(512, 10) # 输出10个类别
  17. def forward(self, x):
  18. # 第一层卷积+池化+ReLU
  19. x = self.pool(F.relu(self.conv1(x)))
  20. # 第二层卷积+池化+ReLU
  21. x = self.pool(F.relu(self.conv2(x)))
  22. # 第三层卷积+池化+ReLU
  23. x = self.pool(F.relu(self.conv3(x)))
  24. # 展平特征图
  25. x = x.view(-1, 128 * 4 * 4)
  26. # 全连接层
  27. x = F.relu(self.fc1(x))
  28. x = self.fc2(x)
  29. return x

模型结构解析

  1. 输入层:32x32x3的RGB图像
  2. 卷积块1:32个3x3卷积核 → ReLU激活 → 2x2最大池化 → 输出尺寸16x16x32
  3. 卷积块2:64个3x3卷积核 → ReLU激活 → 2x2最大池化 → 输出尺寸8x8x64
  4. 卷积块3:128个3x3卷积核 → ReLU激活 → 2x2最大池化 → 输出尺寸4x4x128
  5. 全连接层:展平为2048维向量 → 512维隐藏层 → 10维输出

2.2 模型初始化与设备分配

  1. import torch
  2. # 初始化模型
  3. model = CNN()
  4. # 将模型移动到GPU(如果可用)
  5. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  6. model.to(device)

三、训练流程实现

3.1 定义损失函数和优化器

  1. import torch.optim as optim
  2. # 交叉熵损失函数
  3. criterion = nn.CrossEntropyLoss()
  4. # 随机梯度下降优化器,学习率0.001,动量0.9
  5. optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

3.2 训练循环实现

  1. def train_model(model, trainloader, criterion, optimizer, epochs=10):
  2. for epoch in range(epochs):
  3. running_loss = 0.0
  4. correct = 0
  5. total = 0
  6. for i, data in enumerate(trainloader, 0):
  7. # 获取输入和标签
  8. inputs, labels = data[0].to(device), data[1].to(device)
  9. # 梯度清零
  10. optimizer.zero_grad()
  11. # 前向传播
  12. outputs = model(inputs)
  13. # 计算损失
  14. loss = criterion(outputs, labels)
  15. # 反向传播和优化
  16. loss.backward()
  17. optimizer.step()
  18. # 统计信息
  19. running_loss += loss.item()
  20. _, predicted = torch.max(outputs.data, 1)
  21. total += labels.size(0)
  22. correct += (predicted == labels).sum().item()
  23. # 每200个batch打印一次统计信息
  24. if i % 200 == 199:
  25. print(f'Epoch {epoch+1}, Batch {i+1}, '
  26. f'Loss: {running_loss/200:.3f}, '
  27. f'Acc: {100*correct/total:.2f}%')
  28. running_loss = 0.0
  29. # 每个epoch结束后打印验证准确率
  30. test_acc = evaluate_model(model, testloader)
  31. print(f'Epoch {epoch+1} completed, Test Acc: {test_acc:.2f}%')
  32. def evaluate_model(model, testloader):
  33. correct = 0
  34. total = 0
  35. with torch.no_grad():
  36. for data in testloader:
  37. images, labels = data[0].to(device), data[1].to(device)
  38. outputs = model(images)
  39. _, predicted = torch.max(outputs.data, 1)
  40. total += labels.size(0)
  41. correct += (predicted == labels).sum().item()
  42. return 100 * correct / total

训练流程解析

  1. 外层循环控制训练轮次(epochs)
  2. 内层循环遍历每个batch的数据
  3. 关键步骤:梯度清零→前向传播→计算损失→反向传播→参数更新
  4. 统计训练损失和准确率
  5. 每个epoch结束后在测试集上评估模型性能

四、完整代码与运行示例

4.1 完整实现代码

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import torch.optim as optim
  5. import torchvision
  6. import torchvision.transforms as transforms
  7. import matplotlib.pyplot as plt
  8. import numpy as np
  9. # 设备配置
  10. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  11. # 1. 数据加载
  12. transform = transforms.Compose([
  13. transforms.ToTensor(),
  14. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  15. ])
  16. trainset = torchvision.datasets.CIFAR10(
  17. root='./data', train=True, download=True, transform=transform)
  18. trainloader = torch.utils.data.DataLoader(
  19. trainset, batch_size=32, shuffle=True, num_workers=2)
  20. testset = torchvision.datasets.CIFAR10(
  21. root='./data', train=False, download=True, transform=transform)
  22. testloader = torch.utils.data.DataLoader(
  23. testset, batch_size=32, shuffle=False, num_workers=2)
  24. classes = ('plane', 'car', 'bird', 'cat', 'deer',
  25. 'dog', 'frog', 'horse', 'ship', 'truck')
  26. # 2. 模型定义
  27. class CNN(nn.Module):
  28. def __init__(self):
  29. super(CNN, self).__init__()
  30. self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
  31. self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
  32. self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
  33. self.pool = nn.MaxPool2d(2, 2)
  34. self.fc1 = nn.Linear(128 * 4 * 4, 512)
  35. self.fc2 = nn.Linear(512, 10)
  36. def forward(self, x):
  37. x = self.pool(F.relu(self.conv1(x)))
  38. x = self.pool(F.relu(self.conv2(x)))
  39. x = self.pool(F.relu(self.conv3(x)))
  40. x = x.view(-1, 128 * 4 * 4)
  41. x = F.relu(self.fc1(x))
  42. x = self.fc2(x)
  43. return x
  44. model = CNN().to(device)
  45. # 3. 训练配置
  46. criterion = nn.CrossEntropyLoss()
  47. optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
  48. # 4. 训练函数
  49. def train_model(model, trainloader, criterion, optimizer, epochs=10):
  50. for epoch in range(epochs):
  51. running_loss = 0.0
  52. correct = 0
  53. total = 0
  54. for i, data in enumerate(trainloader, 0):
  55. inputs, labels = data[0].to(device), data[1].to(device)
  56. optimizer.zero_grad()
  57. outputs = model(inputs)
  58. loss = criterion(outputs, labels)
  59. loss.backward()
  60. optimizer.step()
  61. running_loss += loss.item()
  62. _, predicted = torch.max(outputs.data, 1)
  63. total += labels.size(0)
  64. correct += (predicted == labels).sum().item()
  65. if i % 200 == 199:
  66. print(f'Epoch {epoch+1}, Batch {i+1}, '
  67. f'Loss: {running_loss/200:.3f}, '
  68. f'Acc: {100*correct/total:.2f}%')
  69. running_loss = 0.0
  70. test_acc = evaluate_model(model, testloader)
  71. print(f'Epoch {epoch+1} completed, Test Acc: {test_acc:.2f}%')
  72. def evaluate_model(model, testloader):
  73. correct = 0
  74. total = 0
  75. with torch.no_grad():
  76. for data in testloader:
  77. images, labels = data[0].to(device), data[1].to(device)
  78. outputs = model(images)
  79. _, predicted = torch.max(outputs.data, 1)
  80. total += labels.size(0)
  81. correct += (predicted == labels).sum().item()
  82. return 100 * correct / total
  83. # 5. 执行训练
  84. train_model(model, trainloader, criterion, optimizer, epochs=10)
  85. # 6. 保存模型
  86. torch.save(model.state_dict(), 'cifar_cnn.pth')

4.2 运行结果与优化建议

典型运行结果:

  1. Epoch 1, Batch 200, Loss: 2.034, Acc: 32.45%
  2. Epoch 1, Batch 400, Loss: 1.876, Acc: 36.78%
  3. ...
  4. Epoch 10 completed, Test Acc: 72.34%

优化建议

  1. 学习率调整:使用学习率调度器(如torch.optim.lr_scheduler.StepLR
  2. 数据增强:添加随机裁剪、水平翻转等增强操作
  3. 模型改进:尝试ResNet等更先进的架构
  4. 超参数调优:使用网格搜索或贝叶斯优化

五、总结与扩展

本文完整实现了基于PyTorch的图像分类系统,涵盖了从数据加载到模型部署的全流程。关键技术点包括:

  1. 使用torchvision高效加载和预处理数据
  2. 构建包含卷积层、池化层和全连接层的CNN模型
  3. 实现完整的训练循环和评估流程
  4. 提供GPU加速支持和模型保存功能

扩展方向

  • 尝试迁移学习,使用预训练模型(如ResNet)进行微调
  • 实现更复杂的数据增强策略
  • 部署模型到移动端或Web服务
  • 探索更先进的模型架构(如EfficientNet、Vision Transformer)

通过本文的实践,读者可以掌握PyTorch实现图像分类的核心技术,为后续更复杂的计算机视觉任务打下坚实基础。