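Reading and Writing Checkpoints
Last updated: 2025-01-10
bostorchconnector can read and write checkpoints stored on BOS directly, without first copying them to local disk.
The following example saves the state of a ResNet model at the end of an epoch as a checkpoint on BOS, then loads it back: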
Python
# BosClientConfig was used but not imported in the original snippet;
# importing it from the package top level is an assumption.
from bostorchconnector import BosCheckpoint, BosClientConfig

import torchvision
import torch

# Fill in <BUCKET>, <KEY>, and the endpoint that matches your bucket's region
CHECKPOINT_URI = "bos://<BUCKET>/<KEY>/"
ENDPOINT = "http://bj.bcebos.com"
config = BosClientConfig(log_level=1)
checkpoint = BosCheckpoint(endpoint=ENDPOINT, bos_client_config=config)

model = torchvision.models.resnet18()

# Save the checkpoint to BOS
with checkpoint.writer(CHECKPOINT_URI + "epoch0.ckpt") as writer:
    torch.save(model.state_dict(), writer)

# Read the checkpoint back from BOS
with checkpoint.reader(CHECKPOINT_URI + "epoch0.ckpt") as reader:
    state_dict = torch.load(reader)

# Restore the saved weights into the model
model.load_state_dict(state_dict)
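The example above builds the object key by appending "epoch0.ckpt" to CHECKPOINT_URI. When a checkpoint is written every epoch, it can help to centralize that naming. The following is a minimal sketch in plain Python; `epoch_checkpoint_uri` and `latest_epoch` are hypothetical helpers for illustration and are not part of bostorchconnector, and the `epoch<N>.ckpt` naming scheme is simply the one used in the example above.

```python
import re
from typing import Iterable, Optional

# Matches the "epoch<N>.ckpt" suffix used in the example above.
_EPOCH_RE = re.compile(r"epoch(\d+)\.ckpt$")


def epoch_checkpoint_uri(prefix: str, epoch: int) -> str:
    """Hypothetical helper: build '<prefix>/epoch<N>.ckpt'.

    Trailing slashes on the prefix are normalized so that exactly one
    slash separates the prefix from the file name.
    """
    if epoch < 0:
        raise ValueError("epoch must be non-negative")
    return f"{prefix.rstrip('/')}/epoch{epoch}.ckpt"


def latest_epoch(uris: Iterable[str]) -> Optional[int]:
    """Hypothetical helper: return the highest epoch number found in the
    given URIs, or None if none of them follow the naming scheme."""
    epochs = []
    for uri in uris:
        match = _EPOCH_RE.search(uri)
        if match:
            epochs.append(int(match.group(1)))
    return max(epochs, default=None)
```

With these helpers, the save step could be written as `checkpoint.writer(epoch_checkpoint_uri(CHECKPOINT_URI, epoch))`. How the list of existing checkpoint URIs is obtained is left open here, since it depends on how you list objects in your bucket.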