读写 Checkpoint
更新时间:2025-01-10
bostorchconnector 支持直接读写存储在 BOS 上的 checkpoint，无需先下载到本地。
以将 ResNet-18 模型在某个 epoch 的参数（state_dict）写入 checkpoint 并重新加载为例：
# Example: save a ResNet-18 state_dict as a checkpoint on BOS and load it back.
# BosClientConfig is used below, so it must be imported together with BosCheckpoint.
from bostorchconnector import BosCheckpoint, BosClientConfig
import torchvision
import torch

# Fill in <BUCKET>, <KEY> and the endpoint matching the bucket's region.
CHECKPOINT_URI = "bos://<BUCKET>/<KEY>/"
ENDPOINT = "http://bj.bcebos.com"

config = BosClientConfig(log_level=1)
checkpoint = BosCheckpoint(endpoint=ENDPOINT, bos_client_config=config)

model = torchvision.models.resnet18()

# Full BOS URI of the checkpoint object for this epoch; reused for save and load.
ckpt_path = CHECKPOINT_URI + "epoch0.ckpt"

# Save the checkpoint to BOS: torch.save serializes the state_dict into the writer.
with checkpoint.writer(ckpt_path) as writer:
    torch.save(model.state_dict(), writer)

# Read the checkpoint back from BOS.
with checkpoint.reader(ckpt_path) as reader:
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs, and avoids running arbitrary pickled
    # code from a remotely stored file.
    state_dict = torch.load(reader, weights_only=True)

model.load_state_dict(state_dict)