简介:PyTorch加载PT模型:如何正确地加载训练好的模型
PyTorch加载PT模型:如何正确地加载训练好的模型
在PyTorch中,加载已经训练好的模型通常是通过加载保存的模型参数进行的。这是一个非常实用的功能,因为它允许我们复用已经训练过的模型,或者在分布式环境中分布加载模型。这篇文章将介绍PyTorch加载PT模型的基本步骤。
torch.save(model.state_dict(), PATH)保存到磁盘上。其中model是你的PyTorch模型实例,PATH是模型参数将要保存到的路径。state_dict()方法返回一个包含模型所有参数的字典。
import torchimport torchvision.models as models# 实例化一个预训练的ResNet18模型model = models.resnet18(pretrained=True)# 保存模型参数到磁盘torch.save(model.state_dict(), 'resnet18.pt')
load_state_dict()方法加载参数。注意,在加载模型参数后,你需要将模型设置为评估模式(
import torchimport torchvision.models as models# 实例化一个ResNet18模型(不使用预训练参数)model = models.resnet18()# 加载保存的模型参数model.load_state_dict(torch.load('resnet18.pt'))# 确保模型处于评估模式(使用训练好的权重)model.eval()
model.eval()),以确保模型的训练掩码(dropout,batchnorm等)被正确地设置为评估模式。在训练模式下,这些层的行为会有所不同。注意:使用这种方式加载模型进行预测时,不需要将整个模型加载到内存中。只有当模型在计算图中被引用时,PyTorch才会实际加载和计算相应的部分。因此,如果你的模型很大并且输入/输出数据也很大,那么这种方式是非常有用的。
# 假设你有一个4D张量`input`作为输入# input = torch.rand(1, 3, 224, 224) # 一个随机生成的示例输入with torch.no_grad(): # 关闭梯度计算以加速预测output = model(input)# 输出的形状通常与输入相同,除了最后一维,它通常表示预测的类别数# 例如,对于CIFAR10数据集,输出形状为[batch_size, 10]
torch.nn.DataParallel类或者更高级的模型/分布式训练库,如torch.nn.parallel.DistributedDataParallel等。