简介:PyTorch打印State Dictionary
PyTorch打印State Dictionary
在PyTorch中,模型的参数通常存储在state_dict中。state_dict是一个字典,其中键是参数的名称,值是相应的参数值。了解如何打印和查看state_dict对于调试和优化模型非常重要。
以下是如何在PyTorch中打印state_dict的步骤:
pip install torch
import torch
model = torch.nn.DataParallel(torch.nn.parallel.DistributedDataParallel(torch.nn.Sequential(torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),torch.nn.ReLU(inplace=True),torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1),torch.nn.BatchNorm2d(64),torch.nn.ReLU(inplace=True),torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1),torch.nn.BatchNorm2d(64),torch.nn.Linear(9216, 1000),torch.nn.ReLU(),torch.nn.Linear(1000, 1000),torch.nn.ReLU(),torch.nn.Linear(1000, 1000),torch.nn.ReLU())))
state_dict:
state_dict = model.state_dict()
state_dict:这将打印出每个参数的名称和值。注意,
for name, param in state_dict.items():print('Parameter Name:', name)print('Parameter Value:', param)
state_dict中的参数是Tensor对象,因此如果你想查看它们的形状,可以使用.shape属性:
for name, param in state_dict.items():print('Parameter Name:', name)print('Parameter Value:', param)print('Parameter Shape:', param.shape)