简介:PyTorch 06: PyTorch保存和加载模型
PyTorch 06: PyTorch保存和加载模型
在深度学习中,模型的保存和加载是一个非常重要的步骤。这不仅可以在训练新模型时保存已有的训练成果,避免重复工作,还可以在需要时加载已经训练好的模型进行预测。PyTorch提供了非常方便的API来保存和加载模型。
一、保存模型
在PyTorch中,模型的保存通常使用torch.save()函数。torch.save()接受两个参数,一个是待保存的模型对象,另一个是保存的文件路径。以下是一个示例:
# 假设我们有一个已经训练好的模型 `model`model = ...# 保存模型到文件torch.save(model, 'model.pth')
这会将整个模型保存到model.pth文件中。保存模型时,还会自动保存模型的参数(即模型的权重和偏置)。这样,当需要重新加载模型时,可以快速恢复模型的状态。
二、加载模型
在PyTorch中,模型的加载使用torch.load()函数。torch.load()接受一个参数,即模型的保存路径。以下是一个示例:
# 加载模型model = torch.load('model.pth')
这会从model.pth文件中加载模型。需要注意的是,加载模型后,模型的参数会被自动加载,但模型的state_dict(即模型的参数)也会被自动添加到模型中。这意味着,在加载模型后,可以直接使用模型的forward()方法进行预测,而不需要手动设置模型的参数。
三、注意事项
map_location参数来指定加载模型的设备。例如:
# 在CPU上加载模型model = torch.load('model.pth', map_location=torch.device('cpu'))# 在GPU上加载模型model = torch.load('model.pth', map_location=torch.device('cuda'))
torch.save()和torch.load()函数中分别添加optimizer=True和map_location=torch.device('cpu')来实现。例如:
# 保存模型和优化器状态torch.save({'model': model,'optimizer': optimizer,}, 'model_and_optimizer.pth')# 加载模型和优化器状态checkpoint = torch.load('model_and_optimizer.pth', map_location=torch.device('cpu'))model = checkpoint['model']optimizer = checkpoint['optimizer']