模型训练Trainer使用说明
更新时间:2024-10-21
功能介绍
千帆ModelBuilder Python SDK支持调用Trainer相关API,支持对数据集进行自定义训练。本文使用千帆ModelBuilder SFT语言大模型为例介绍。
注意事项
- 调用本文API,需使用安全认证AK/SK鉴权,调用流程及鉴权介绍详见SDK安装及使用流程。
-
本文涉及以下函数列表
- 加载数据集Dataset.load()
- 创建Trainer LLMFinetune()
- 配置自定义训练参数TrainConfig()
- 查询训练参数ModelInfoMapping()
- 查询训练参数默认值DefaultTrainConfigMapping()
- 启动训练任务run()
- 重启训练任务resume()
- 日志打印enable_log()
调用流程简介
(1)打印日志
注意:如果无需打印日志,可跳过此步骤。
如果需打印过程日志,通过调用enable_log()实现。
(2)准备并加载数据集
注意:数据集要求,必须是有标注的非排序对话数据集。
加载千帆ModelBuilder上的数据集,通过调用Dataset.load()实现,详见参数说明。
(3)创建Trainer
调用LLMFInetune()创建Trainer对象,此步骤也会初步校验TrainConfig参数,如果有不符合的字段会打印warning日志。
注意:如果需要自定义训练参数,可通过调用TrainConfig()实现,详见参数说明。
(4)启动训练任务
通过调用run()实现。
(5)重启训练任务
注意:如果无需重启任务,可跳过此步骤。
如果突发断电或者任务停止,可以使用resume()重启任务。
调用示例
未自定义训练参数
import os
import qianfan
# 使用安全认证AK/SK鉴权,通过环境变量方式初始化;替换下列示例中参数,安全认证Access Key替换your_iam_ak,Secret Key替换your_iam_sk
os.environ["QIANFAN_ACCESS_KEY"] = "your_iam_ak"
os.environ["QIANFAN_SECRET_KEY"] = "your_iam_sk"
# 如果希望打印过程日志,通过调用enable_log(logging.INFO)启用打印日志功能
#from qianfan.utils import enable_log
#import logging
#enable_log(logging.INFO) # 设置打印日志的最低级别
from qianfan.dataset import Dataset
from qianfan.trainer import LLMFinetune
# 加载千帆ModelBuilder数据集,is_download_to_local=False表示不下载数据集到本地,而是直接使用
ds: Dataset = Dataset.load(qianfan_dataset_id="your_dataset_id", is_download_to_local=False)
# 新建trainer LLMFinetune,需最少传入train_type和dataset
# 注意fine-tune任务需要指定的数据集类型要求为有标注的非排序对话数据集。
trainer = LLMFinetune(
train_type="ERNIE-xx",
dataset=ds,
)
trainer.run()
# 如果突发断电或者任务停止,可以使用resume函数重启任务
# trainer.resume()
自定义训练参数
import os
import qianfan
# 使用安全认证AK/SK鉴权,通过环境变量方式初始化;替换下列示例中参数,安全认证Access Key替换your_iam_ak,Secret Key替换your_iam_sk
os.environ["QIANFAN_ACCESS_KEY"] = "your_iam_ak"
os.environ["QIANFAN_SECRET_KEY"] = "your_iam_sk"
# 如果希望打印过程日志,通过调用enable_log(logging.INFO)启用打印日志功能
#from qianfan.utils import enable_log
#import logging
#enable_log(logging.INFO) # 设置打印日志的最低级别
from qianfan.dataset import Dataset
from qianfan.trainer import LLMFinetune
from qianfan.trainer.configs import TrainConfig
# 加载千帆ModelBuilder的数据集。
# qianfan_dataset_id是数据集id,类型要求为有标注的非排序对话数据集is_download_to_local=False表示不下载数据集到本地,直接使用
ds: Dataset = Dataset.load(qianfan_dataset_id="your_dataset_id", is_download_to_local=False)
# 发起训练任务。以基础模型ERNIE-xx为例,需要指定的数据集类型要求为有标注的非排序对话数据集。
trainer = LLMFinetune(
train_type="ERNIE-xx",
dataset=ds,
peft_type="LoRA",
# 自定义训练参数
train_config=TrainConfig(
epochs=1, # 迭代轮次(Epoch),控制训练过程中的迭代轮数。
# batch_size=32, # 批处理大小(BatchSize)表示在每次训练迭代中使用的样本数。较大的批处理大小可以加速训练.部分模型可能无需填写该字段
learning_rate=0.00004, # 学习率(LearningRate)是在梯度下降的过程中更新权重时的超参数,过高会导致模型难以收敛,过低则会导致模型收敛速度过慢,
)
)
trainer.run()
# 如果突发断电或者任务停止,可以使用resume函数重启任务
# trainer.resume()
函数列表
模型训练Trainer需使用的部分函数如下。
- 加载数据集
- 创建Trainer
- 配置自定义训练参数
- 查询训练参数
- 查询训练参数默认值
加载数据集
加载千帆ModelBuilder的数据集。
示例
Dataset.load(qianfan_dataset_id="your_dataset_id", is_download_to_local=False)
请求参数
名称 | 类型 | 必填 | 描述 |
---|---|---|---|
qianfan_dataset_id | string | 是 | 要导入的数据集版本ID,说明: (1)可以通过以下任一方式获取该字段值: · 方式一,通过调用创建数据集接口,返回的datasetId字段获取。 · 方式二,在千帆ModelBuilder控制台-数据集管理列表页面,点击详情,在版本信息页查看,如下图所示: (2)数据集类型,要求为有标注的非排序对话数据集 |
is_download_to_local | bool | 是 | 是否下载数据到本地。 True:下载数据集到本地 False:不下载数据集到本地,而是直接使用 |
创建Trainer
调用LLMFInetune()创建Trainer对象,此步骤也会初步校验TrainConfig参数,如果有不符合的字段会打印warning日志。
示例
LLMFinetune(
train_type="ERNIE-xx"
)
请求参数
参数名 | 数据类型 | 必填 | 描述 |
---|---|---|---|
train_type | String | 是 | 模型版本,示例:ERNIE-Lite-8K-0922,可以通过以下方法获取具体值: 在千帆ModelBuilder控制台-模型调优-SFT页面-点击创建训练任务,选择基础模型,查看模型版本,如下图所示: |
dataset | Optional[Any] | 否 | 一个数据集实例。说明: 数据集dataset和此参数,至少填写一个 |
train_config | Union[TrainConfig, string] | 否 | 用于微调训练参数的TrainConfig。说明: 如果不填写此参数,将使用不同模型的默认参数。 |
deploy_config | DeployConfig | 否 | 用于模型服务部署参数的DeployConfig。说明: 如果需要部署服务,此参数必填。 |
event_handler | EventHandler | 否 | 用于接收训练过程中事件处理的EventHandler实例。 |
base_model | String | 否 | 基础模型,示例:ChatGLM2 |
eval_dataset | Optional[Any] | 否 | 可选的评价数据集 |
evaluators | List[Evaluator] | 否 | 用于评估的评估器列表 |
dataset_bos_path | String | 否 | 训练用的 bos 路径,说明: 数据集dataset和此参数,至少填写一个 |
配置自定义训练参数
示例
TrainConfig(
epochs=1,
batch_size=32,
learning_rate=0.00004,
)
请求参数
模型不同,训练配置使用的参数不同。可以通过以下任一方式查询请求参数:
- 方式一:通过提供的请求参数列表
- 方式二:通过调用查询训练参数ModelInfoMapping(),获取参数列表
- 方式一:请求参数列表
说明:下列表格中的模型支持情况,请参考模型支持情况。
名称 | 类型 | 必填 | 描述 |
---|---|---|---|
epoch | int | 否 | 迭代轮次,说明:该字段取值详情参考模型支持情况 |
learningRate | float | 否 | 学习率,说明:说明:该字段取值详情参考模型支持情况 |
batchSize | int | 否 | 批处理大小,说明:该字段取值更多详情参考模型支持情况 |
maxSeqLen | int | 否 | 序列长度,说明:该字段取值详情参考模型支持情况 |
loggingSteps | int | 否 | 保存日志间隔,说明: (1)当为以下情况,该字段必填 · model为ERNIE-Speed-8K,且trainMode为SFT · model为ERNIE-Lite-8K-0922,且trainMode为SFT · model为ERNIE-Lite-8K-0308,且trainMode为SFT · model为ERNIE-Tiny-8K,且trainMode为SFT (2)取值范围[1, 100],默认值为1 |
warmupRatio | float | 否 | 预热比例,说明:该字段取值详情参考模型支持情况 |
weightDecay | float | 否 | 正则化系数,说明:该字段取值详情参考模型支持情况 |
loraRank | int | 否 | LoRA 策略中的秩,说明:该字段取值详情参考模型支持情况 |
loraAlpha | int | 否 | 说明:说明:该字段取值更多详情参考模型支持情况 |
loraAllLinear | string | 否 | LoRA 所有线性层,说明:该字段取值详情参考模型支持情况 |
loraTargetModules | string[] | 否 | 说明:该字段取值详情参考模型支持情况 |
loraDropout | float | 否 | 说明:该字段取值更多详情参考模型支持情况 |
schedulerName | string | 否 | 说明:该字段取值详情参考模型支持情况 |
Packing | bool | 否 | 可选值:true 或 false,默认值false,说明:该字段取值详情参考模型支持情况 |
extras | Dict[str, Any] | {} | 其他参数字典,保留值 |
- 方式二:通过接口查询参数
详见查询训练参数介绍。
查询训练参数
请求示例
from qianfan.trainer.configs import ModelInfoMapping
print(ModelInfoMapping['ERNIE-xx'])
返回示例
short_name='xxx'
base_model_type='ERNIE-Lite-8K-0922'
support_peft_types=[<PeftType.ALL: 'ALL'>, <PeftType.LoRA: 'LoRA'>]
common_params_limit=TrainLimit(
batch_size_limit=(1, 4),
max_seq_len_options=[4096, 8192], epoch_limit=(1, 50),
learning_rate_limit=(2e-07, 0.0002),
log_steps_limit=None,
warmup_ratio_limit=None,
weight_decay_limit=None,
lora_rank_options=None,
lora_alpha_options=None,
lora_dropout_limit=None,
scheduler_name_options=None
)
specific_peft_types_params_limit={
'ALL': TrainLimit(
batch_size_limit=None,
max_seq_len_options=None,
epoch_limit=None,
learning_rate_limit=(1e-05, 4e-05),
log_steps_limit=None,
warmup_ratio_limit=None,
weight_decay_limit=None,
lora_rank_options=None,
lora_alpha_options=None,
lora_dropout_limit=None,
scheduler_name_options=None
),
'LoRA': TrainLimit(
batch_size_limit=None,
max_seq_len_options=None,
epoch_limit=None,
learning_rate_limit=(3e-05, 0.001),
log_steps_limit=None,
warmup_ratio_limit=None,
weight_decay_limit=None,
lora_rank_options=None,
lora_alpha_options=None,
lora_dropout_limit=None,
scheduler_name_options=None
)
}
请求参数
名称 | 类型 | 描述 |
---|---|---|
train_type | string | 模型版本,可以通过以下方法获取具体值: 在千帆ModelBuilder控制台-模型调优-SFT页面-点击创建训练任务,选择基础模型,查看模型版本,如下图所示: |
返回参数
说明:下列表格中的模型支持情况,请参考模型支持情况。
名称 | 类型 | 必填 | 描述 |
---|---|---|---|
epoch | int | 否 | 迭代轮次,说明:该字段取值详情参考模型支持情况 |
learningRate | float | 否 | 学习率,说明:说明:该字段取值详情参考模型支持情况 |
batchSize | int | 否 | 批处理大小,说明:该字段取值更多详情参考模型支持情况 |
maxSeqLen | int | 否 | 序列长度,说明:该字段取值详情参考模型支持情况 |
loggingSteps | int | 否 | 保存日志间隔,说明: (1)当为以下情况,该字段必填 · model为ERNIE-Speed-8K,且trainMode为SFT · model为ERNIE-Lite-8K-0922,且trainMode为SFT · model为ERNIE-Lite-8K-0308,且trainMode为SFT · model为ERNIE-Tiny-8K,且trainMode为SFT (2)取值范围[1, 100],默认值为1 |
warmupRatio | float | 否 | 预热比例,说明:该字段取值详情参考模型支持情况 |
weightDecay | float | 否 | 正则化系数,说明:该字段取值详情参考模型支持情况 |
loraRank | int | 否 | LoRA 策略中的秩,说明:该字段取值详情参考模型支持情况 |
loraAlpha | int | 否 | 说明:说明:该字段取值更多详情参考模型支持情况 |
loraAllLinear | string | 否 | LoRA 所有线性层,说明:该字段取值详情参考模型支持情况 |
loraTargetModules | string[] | 否 | 说明:该字段取值详情参考模型支持情况 |
loraDropout | float | 否 | 说明:该字段取值更多详情参考模型支持情况 |
schedulerName | string | 否 | 说明:该字段取值详情参考模型支持情况 |
Packing | bool | 否 | 可选值:true 或 false,默认值false,说明:该字段取值详情参考模型支持情况 |
extras | Dict[str, Any] | {} | 其他参数字典,保留值 |
查询训练参数默认值
请求示例
from qianfan.trainer.configs import DefaultTrainConfigMapping
print(DefaultTrainConfigMapping['ERNIE-xx'])
返回示例
epoch=1
batch_size=None
learning_rate=3e-05
max_seq_len=4096
peft_type='LoRA'
trainset_rate=20
logging_steps=None
warmup_ratio=None
weight_decay=None
lora_rank=None
lora_all_linear=None
scheduler_name=None
lora_alpha=None
lora_dropout=None
extras={}
请求参数
名称 | 类型 | 描述 |
---|---|---|
train_type | string | 模型版本,示例:BLOOMZ_7B,可以通过以下方法获取具体值: 在千帆ModelBuilder控制台-模型调优-SFT页面-点击创建训练任务,选择基础模型,查看模型版本,如下图所示: |
返回参数
说明:下列表格中的模型支持情况,请参考模型支持情况。
名称 | 类型 | 必填 | 描述 |
---|---|---|---|
epoch | int | 否 | 迭代轮次,说明:该字段取值详情参考模型支持情况 |
learningRate | float | 否 | 学习率,说明:说明:该字段取值详情参考模型支持情况 |
batchSize | int | 否 | 批处理大小,说明:该字段取值更多详情参考模型支持情况 |
maxSeqLen | int | 否 | 序列长度,说明:该字段取值详情参考模型支持情况 |
loggingSteps | int | 否 | 保存日志间隔,说明: (1)当为以下情况,该字段必填 · model为ERNIE-Speed-8K,且trainMode为SFT · model为ERNIE-Lite-8K-0922,且trainMode为SFT · model为ERNIE-Lite-8K-0308,且trainMode为SFT · model为ERNIE-Tiny-8K,且trainMode为SFT (2)取值范围[1, 100],默认值为1 |
warmupRatio | float | 否 | 预热比例,说明:该字段取值详情参考模型支持情况 |
weightDecay | float | 否 | 正则化系数,说明:该字段取值详情参考模型支持情况 |
loraRank | int | 否 | LoRA 策略中的秩,说明:该字段取值详情参考模型支持情况 |
loraAlpha | int | 否 | 说明:说明:该字段取值更多详情参考模型支持情况 |
loraAllLinear | string | 否 | LoRA 所有线性层,说明:该字段取值详情参考模型支持情况 |
loraTargetModules | string[] | 否 | 说明:该字段取值详情参考模型支持情况 |
loraDropout | float | 否 | 说明:该字段取值更多详情参考模型支持情况 |
schedulerName | string | 否 | 说明:该字段取值详情参考模型支持情况 |
Packing | bool | 否 | 可选值:true 或 false,默认值false,说明:该字段取值详情参考模型支持情况 |
extras | Dict[str, Any] | {} | 其他参数字典,保留值 |