logo
话题头图

【千帆SDK】使用文生图数据集进行模型微调生成Pokemon风格图片

💡学习前小提示
请大家点击链接并加🌟:https://github.com/baidubce/bce-qianfan-sdk
trainer 发起 finetune 中,我们已经学习了如何使用trainer+dataset发起文生文微调任务,同时也体验了模型评估,批量推理,服务部署等流程;除了纯文本的生成模型外,千帆平台也提供了针对文心一格,以及开源的StableDiffusion模型的训练微调。
本例将基于qianfan==0.3.6.1展示通过Dataset加载本地数据集,并上传到千帆平台,基于Stable-Diffusion进行fine-tune,以实现pokemon风格的图片生成能力。
  
  
  
  
  
  
! pip install "qianfan[dataset_base]" -U
! pip install datasets==2.14.6 # huggingface datasets
! pip install fsspec==2023.9.2 # fix load_dataset error
  
  
  
  
  
  
import qianfan
qianfan.__version__

前置准备

  • 初始化千帆安全认证AK、SK
  
  
  
  
  
  
import os
os.environ["QIANFAN_ACCESS_KEY"] = "your_ak"
os.environ["QIANFAN_SECRET_KEY"] = "your_sk"

导入依赖

  • qianfan.trainer.consts trainer使用中所用到的常量
  • qianfan.resources.console.consts api层面定义的字段常量
  • qianfan.trainer.configs trainer使用所需要的config配置数据类
  • qianfan.resources.QfMessages 用于组装qianfan.ChatCompletion的输入messages
  • qianfan.trainer.finetune.Finetune 大语言模型fine-tune任务Trainer实现
  • qianfan.dataset.Dataset 千帆dataset类,用于管理千帆平台、本地、第三方数据集的导入导出,数据清洗等操作
  
  
  
  
  
  
from qianfan.trainer.consts import ActionState
from qianfan.model.consts import ServiceType
from qianfan.resources.console import consts as console_consts
from qianfan.trainer.configs import TrainConfig
from qianfan.model.configs import DeployConfig
from qianfan.resources import QfMessages
from qianfan.trainer.finetune import Finetune
from qianfan.dataset import Dataset
from qianfan.utils import enable_log
import logging
enable_log(logging.INFO)
我们此次选用huggingface的开源数据集用于生成pokemon风格的图片
  
  
  
  
  
  
# 从huggingface 导入数据集:
import datasets
dataset = datasets.load_dataset("svjack/pokemon-blip-captions-en-zh", split='train')
  
  
  
  
  
  
dataset.column_names
# 输出:['image', 'en_text', 'zh_text']
将huggingface上的文生图数据集,增加指令,并转存成本地的数据集目录
  
  
  
  
  
  
import os
import json
pokemon_style_instruction = "pokemon,"
save_ds_dir = "./pokemon_ds"
if not os.path.exists(save_ds_dir):
os.mkdir(save_ds_dir)
for i, v in enumerate(dataset):
v["image"].save(f"{save_ds_dir}/{i}.jpg")
with open(f"{save_ds_dir}/{i}.json", "w") as f:
json.dump({"prompt": f'{pokemon_style_instruction} {v["en_text"]}'}, f)

数据集加载

千帆SDK提供了数据集实现帮助我们可以快速的加载本地的数据集到内存,并通过设定DataSource数据源以保存至本地和千帆平台。
  
  
  
  
  
  
from qianfan.dataset import Dataset
from qianfan.dataset.data_source import FileDataSource
from qianfan.dataset.data_source.base import FormatType
file_data_source = FileDataSource(path=save_ds_dir, file_format=FormatType.Text2Image)
ds = Dataset.load(file_data_source)
print(ds.list(0))
从本地数据集上传到BOS
  
  
  
  
  
  
# 保存到千帆平台
from qianfan.dataset.data_source import QianfanDataSource
from qianfan.resources.console import consts as console_consts
bos_bucket_name = "sdk-test"
bos_bucket_file_path = "/sdk_ds/"
qianfan_dataset_name = "random_sdk_train_t2i"
# 创建千帆数据集,并上传保存
qianfan_data_source = QianfanDataSource.create_bare_dataset(
name=qianfan_dataset_name,
template_type=console_consts.DataTemplateType.Text2Image,
storage_type=console_consts.DataStorageType.PrivateBos,
storage_id=bos_bucket_name,
storage_path=bos_bucket_file_path,
)
ds = ds.save(qianfan_data_source)

发起图生文训练

这里我们选用Stable-Diffusion-XL-Base-1.0作为基础模型,
  
  
  
  
  
  
from qianfan.trainer.consts import PeftType
trainer = Finetune(
train_type="Stable-Diffusion-XL-Base-1.0",
train_config=TrainConfig(
peft_type=PeftType.LoRA,
batch_size=8,
epoch=20,
learning_rate=0.00005,
),
dataset=ds,
)

运行任务

同步运行trainer,训练直到模型发布完成
  
  
  
  
  
  
trainer.run()
获取finetune任务输出:
  
  
  
  
  
  
trainer.output
使用sdk发起部署流程,这一步需要到前端控制台进行支付才能完成:
  
  
  
  
  
  
#-# cell_skip
from qianfan.model import Service, Model
from qianfan.model.consts import ServiceType
from qianfan.resources.console.consts import DeployPoolType
# 从训练结果中获取模型对象
m: Model = trainer.output["model"]
sft_svc: Service = m.deploy(DeployConfig(
name="random_t2i_sdk1",
endpoint_prefix="sdpoke1",
replicas=1, # 副本数, 与qps强绑定
pool_type=DeployPoolType.PrivateResource, # 私有资源池
service_type=ServiceType.Text2Image,
))
使用Finetune之后的模型服务调用:
  
  
  
  
  
  
#-# cell_skip
from qianfan.resource import Text2Image
### 使用Model & Service调用模型
problem="pokemon, a blue monkey with a hat"
#获取服务对象,即ChatCompletion等类型的对象
t2i: qianfan.Text2Image = sft_svc.get_res()
from PIL import Image
import io
resp = t2i.do(prompt=problem, with_decode="base64")
img_data = resp["body"]["data"][0]["image"]
img = Image.open(io.BytesIO(img_data))
display(img)
评论
用户头像