简介:PyTorch 序列化
PyTorch 序列化
在深度学习中,模型训练和优化后的模型保存和加载是至关重要的。PyTorch 提供了一种方便的序列化机制,允许我们将模型保存到磁盘上,并在需要时加载回来。这种序列化机制对于模型部署、迁移学习和微调等场景非常有用。
PyTorch 的序列化机制基于 Python 的 pickle 模块,可以将几乎任何 Python 对象保存到磁盘上。然而,对于深度学习模型,我们通常希望以一种更紧凑和高效的方式进行序列化和反序列化。为此,PyTorch 提供了 torch.save() 和 torch.load() 函数,专门用于保存和加载 PyTorch 模型和张量。
torch.save() 函数可以将一个 PyTorch 模型或张量保存到一个二进制文件中。该函数接受两个参数:要保存的对象和保存的文件名。它还支持将其他 Python 对象作为关键字参数传递,这些对象将被序列化并保存到文件中。
例如,以下代码演示了如何使用 torch.save() 函数保存一个简单的 PyTorch 模型:
import torchimport torch.nn as nn# 定义一个简单的 PyTorch 模型model = nn.Sequential(nn.Linear(10, 5),nn.ReLU(),nn.Linear(5, 2))# 将模型保存到磁盘上torch.save(model.state_dict(), "model.pth")
在上面的代码中,我们首先定义了一个简单的 PyTorch 模型,然后使用 model.state_dict() 方法获取模型的参数,并将其保存到名为 “model.pth” 的文件中。torch.save() 函数将模型的参数保存为一个字典,其中键是参数的名称,值是参数的值。
要加载一个 PyTorch 模型,可以使用 torch.load() 函数。该函数接受一个参数,即要加载的文件名,并返回一个字典,其中包含模型的参数。然后,可以使用模型的 load_state_dict() 方法将参数加载到模型中。
例如,以下代码演示了如何使用 torch.load() 函数加载一个保存的模型:
import torchimport torch.nn as nn# 从磁盘上加载模型的参数model_params = torch.load("model.pth")# 定义一个与原始模型结构相同的模型model = nn.Sequential(nn.Linear(10, 5),nn.ReLU(),nn.Linear(5, 2))# 将参数加载到模型中model.load_state_dict(model_params)
在上面的代码中,我们首先使用 torch.load() 函数加载了保存的模型参数,然后定义了一个与原始模型结构相同的模型。最后,我们使用 load_state_dict() 方法将加载的参数加载到模型中。