简介:本文将介绍如何在 PyTorch 中保存和加载模型,以及如何查看模型的结构。我们将从最基础的方法开始,不涉及优化器保存和部分参数加载等进阶方法。
PyTorch 提供了一系列简单易用的函数,用于保存和加载模型。以下是如何使用这些函数的基本步骤。
保存模型
在训练模型后,您可以使用 torch.save() 函数将模型保存到磁盘。
# 假设 model 是您已经训练好的模型torch.save(model.state_dict(), 'model_weights.pth')
在这里,model.state_dict() 返回一个包含模型参数的字典,而 'model_weights.pth' 是您要保存的文件名。
加载模型
加载模型的过程分为两步:首先加载模型的参数,然后使用这些参数来加载模型。
# 加载模型的参数model_weights = torch.load('model_weights.pth')# 假设 model 是您已经定义好的模型架构model = model(pretrained=False) # 假设您有一个预训练的模型架构,这里设置为 False 表示不使用预训练权重model.load_state_dict(model_weights)
在这里,torch.load() 函数用于加载保存的参数,然后使用 load_state_dict() 方法将这些参数应用到模型上。
查看模型结构
PyTorch 的 torch.nn 模块提供了一个方便的方法来查看神经网络的结构。您可以使用 print(model) 来打印模型的结构。对于更详细的输出,您可以使用 model.summary()(如果可用)。例如,对于 VGG16 这样的预训练模型:
```python
print(model) # 打印模型整体结构
model.summary() # 打印模型详细结构(如果可用)