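Reading and Writing Checkpoints
Last updated: 2026-03-05
bostorchconnector can read and write checkpoints stored on BOS directly.
The following example saves a ResNet checkpoint for one epoch to BOS and then loads it back: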
Python
from bostorchconnector import BosCheckpoint
# Assumption: BosClientConfig is exported from the package root;
# the original snippet used it without importing it.
from bostorchconnector import BosClientConfig

import torchvision
import torch

# Fill in <BUCKET>, <KEY>, and the matching endpoint
CHECKPOINT_URI = "bos://<BUCKET>/<KEY>/"
ENDPOINT = "http://bj.bcebos.com"
config = BosClientConfig(log_level=1)
checkpoint = BosCheckpoint(endpoint=ENDPOINT, bos_client_config=config)

model = torchvision.models.resnet18()

# Save the checkpoint to BOS
with checkpoint.writer(CHECKPOINT_URI + "epoch0.ckpt") as writer:
    torch.save(model.state_dict(), writer)

# Read the checkpoint back from BOS
with checkpoint.reader(CHECKPOINT_URI + "epoch0.ckpt") as reader:
    state_dict = torch.load(reader)

model.load_state_dict(state_dict)
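The example builds the object name by concatenating a `bos://` prefix with `epoch0.ckpt`. When saving once per epoch, a small helper keeps the names consistent. The sketch below uses only the standard library; `epoch_checkpoint_uri` is a hypothetical helper name, not part of bostorchconnector.

```python
from urllib.parse import urlparse


def epoch_checkpoint_uri(base_uri: str, epoch: int) -> str:
    """Build '<base_uri>/epoch<N>.ckpt' for a bos:// prefix (hypothetical helper)."""
    parsed = urlparse(base_uri)
    # Reject anything that is not of the form bos://<BUCKET>/...
    if parsed.scheme != "bos" or not parsed.netloc:
        raise ValueError(f"expected a bos://<BUCKET>/<KEY>/ URI, got {base_uri!r}")
    # Tolerate a missing or duplicated trailing slash on the prefix.
    return f"{base_uri.rstrip('/')}/epoch{epoch}.ckpt"
```

The returned string can be passed to `checkpoint.writer(...)` or `checkpoint.reader(...)` in place of the hand-built `CHECKPOINT_URI + "epoch0.ckpt"`.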
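Distributed Checkpoint (DCP)
bostorchconnector supports PyTorch Distributed Checkpoint through the following classes:
- BosStorageWriter: implements PyTorch's StorageWriter interface.
- BosStorageReader: implements PyTorch's StorageReader interface.
- BosFileSystem: implements PyTorch's FileSystemBase interface.
These classes integrate BOS with PyTorch Distributed Checkpoint, so distributed model checkpoints can be saved to and loaded from BOS efficiently.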
Prerequisites and Installation
PyTorch 2.3 or later is required. Install the package with the dcp extra:
Shell
pip install "bostorchconnector[dcp]"
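To check the PyTorch 2.3 requirement before running a job, the version string can be compared as a tuple of integers. This is a minimal sketch; `version_tuple` and `supports_dcp` are hypothetical helper names. Comparing tuples rather than raw strings matters because `"2.10"` sorts before `"2.3"` as text.

```python
import re

MIN_TORCH = (2, 3)  # minimum (major, minor) stated in the prerequisites


def version_tuple(version: str) -> tuple:
    """Extract (major, minor) from strings such as '2.3.0' or '2.2.2+cu121'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def supports_dcp(torch_version: str) -> bool:
    """Return True when the given PyTorch version meets the 2.3 minimum."""
    return version_tuple(torch_version) >= MIN_TORCH
```

In a real environment the argument would typically be `torch.__version__`.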
Example
Python
from bostorchconnector.dcp import BosStorageWriter, BosStorageReader
# Assumption: BosClientConfig is exported from the package root;
# the original snippet used it without importing it.
from bostorchconnector import BosClientConfig

import torchvision
import torch.distributed.checkpoint as DCP

# Configuration
CHECKPOINT_URI = "bos://<BUCKET>/<KEY>/"
ENDPOINT = "http://bj.bcebos.com"

model = torchvision.models.resnet18()

# Tunable client options: credentials_path, log_level, log_path, part_size, pool_threads_num, max_attempts
config = BosClientConfig(part_size=16 * 1024 * 1024, pool_threads_num=64)

# Save a distributed checkpoint to BOS
bos_storage_writer = BosStorageWriter(
    endpoint=ENDPOINT,
    path=CHECKPOINT_URI,
    bos_client_config=config,  # optional
    thread_count=4,  # optional, number of IO threads used for writing
    overwrite=True,
)
DCP.save(
    state_dict=model.state_dict(),
    storage_writer=bos_storage_writer,
)

# Load the distributed checkpoint from BOS
model = torchvision.models.resnet18()
model_state_dict = model.state_dict()
bos_storage_reader = BosStorageReader(
    endpoint=ENDPOINT,
    path=CHECKPOINT_URI,
    bos_client_config=config,  # optional
)
DCP.load(
    state_dict=model_state_dict,
    storage_reader=bos_storage_reader,
)
model.load_state_dict(model_state_dict)
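The example sets `part_size` to 16 MiB. Assuming `part_size` is the chunk size in bytes used for multipart uploads (the source does not spell this out), the number of parts for an object of a given size is a ceiling division. The sketch below only does that arithmetic; `multipart_part_count` is a hypothetical helper name.

```python
MIB = 1024 * 1024


def multipart_part_count(object_size: int, part_size: int) -> int:
    """Number of parts needed to upload object_size bytes in part_size chunks.

    Assumes part_size is a byte count, as in the 16 * 1024 * 1024 setting above.
    """
    if part_size <= 0:
        raise ValueError("part_size must be positive")
    if object_size < 0:
        raise ValueError("object_size must not be negative")
    # Integer ceiling division avoids floating-point rounding on large sizes.
    return -(-object_size // part_size)
```

For instance, a 100 MiB shard with a 16 MiB part size needs 7 parts, so a larger `part_size` means fewer requests per object at the cost of more memory per in-flight part.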
PyTorch Lightning
bostorchconnector includes a PyTorch Lightning integration. It provides BosLightningCheckpoint, which implements Lightning's CheckpointIO interface, so checkpoints can be read from and written to BOS from within PyTorch Lightning.
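For readers unfamiliar with CheckpointIO, the interface boils down to three methods: save, load, and remove. The sketch below is a dependency-free, in-memory stand-in that mirrors those method names and arguments. It is not how BosLightningCheckpoint is implemented, and `InMemoryCheckpointIO` is a hypothetical class that does not subclass Lightning's CheckpointIO.

```python
import pickle
from typing import Any, Dict, Optional


class InMemoryCheckpointIO:
    """Hypothetical stand-in mirroring the three CheckpointIO methods."""

    def __init__(self) -> None:
        # Maps a checkpoint path to its serialized bytes.
        self._objects: Dict[str, bytes] = {}

    def save_checkpoint(
        self,
        checkpoint: Dict[str, Any],
        path: str,
        storage_options: Optional[Any] = None,
    ) -> None:
        # A storage-backed plugin would upload the bytes to `path` instead.
        self._objects[str(path)] = pickle.dumps(checkpoint)

    def load_checkpoint(
        self, path: str, map_location: Optional[Any] = None
    ) -> Dict[str, Any]:
        # A storage-backed plugin would download and deserialize the object.
        return pickle.loads(self._objects[str(path)])

    def remove_checkpoint(self, path: str) -> None:
        # Called when an old checkpoint is no longer needed, e.g. under save_top_k.
        del self._objects[str(path)]
```

A plugin with this shape is what the `plugins=[...]` argument of `Trainer` receives; Lightning then routes checkpoint reads, writes, and deletions through it.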
Installation
Shell
pip install "bostorchconnector[lightning]"
Example
Python
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos import LightningTransformer, WikiText2
from torch.utils.data import DataLoader
from bostorchconnector.lightning import BosLightningCheckpoint

from fsspec.registry import register_implementation
import bosfs

# ...

CHECKPOINT_URI = "bos://<BUCKET>/<KEY>/"
ENDPOINT = "http://bj.bcebos.com"

save_only_latest = True

# Register the bos:// scheme with fsspec so Lightning can resolve the path
register_implementation("bos", bosfs.BOSFileSystem)

dataset = WikiText2()
dataloader = DataLoader(dataset, num_workers=2)

model = LightningTransformer(vocab_size=dataset.vocab_size)
bos_lightning_checkpoint = BosLightningCheckpoint(endpoint=ENDPOINT)

checkpoint_callback = ModelCheckpoint(
    dirpath=CHECKPOINT_URI,
    save_top_k=1 if save_only_latest else -1,
    every_n_train_steps=1,
    filename="checkpoint-{epoch:02d}-{step:02d}",
    enable_version_counter=True,
)

trainer = Trainer(
    plugins=[bos_lightning_checkpoint],
    callbacks=[checkpoint_callback],
    min_epochs=4,
    max_epochs=5,
    max_steps=3,
)
trainer.fit(model, dataloader)

# Read the checkpoint back
r_trainer = Trainer(
    plugins=[bos_lightning_checkpoint],
    min_epochs=4,
    max_epochs=5,
    max_steps=3,
)
# Load the checkpoint in `ckpt_path` before training
r_trainer.fit(model, dataloader, ckpt_path=CHECKPOINT_URI + "checkpoint-epoch=00-step=03.ckpt")
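The `ckpt_path` above ends in `checkpoint-epoch=00-step=03.ckpt` although the template is `checkpoint-{epoch:02d}-{step:02d}`. By default, ModelCheckpoint prefixes each placeholder with its name (`auto_insert_metric_name=True`) and appends `.ckpt`. The sketch below is a simplified imitation of that expansion for illustration only; it is not Lightning's own code, and `expand_checkpoint_name` is a hypothetical helper.

```python
import re


def expand_checkpoint_name(template: str, values: dict) -> str:
    """Imitate ModelCheckpoint's default naming for simple templates."""
    # Turn "{epoch:02d}" into "epoch={epoch:02d}" so the name appears in the output.
    with_names = re.sub(r"\{(\w+)", r"\1={\1", template)
    # Fill in the placeholders and add the checkpoint extension.
    return with_names.format(**values) + ".ckpt"


name = expand_checkpoint_name("checkpoint-{epoch:02d}-{step:02d}", {"epoch": 0, "step": 3})
print(name)  # checkpoint-epoch=00-step=03.ckpt
```

This is useful when constructing the `ckpt_path` for a resume run from a known epoch and step.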
