Base Model / LoRA Weight Merge Tool
Last updated: 2024-09-08
Currently, only weights in the Hugging Face format can be merged. Weights in the Megatron format must first be converted to the Hugging Face format before using this tool.
Parameter | Type | Argument name | Description |
---|---|---|---|
Base model weight path | str | base_model_path | Path to the base (pre-trained) weights of the model |
LoRA weight path | str | lora_path | Path to the LoRA adapter weights of the model |
Merged weight output path | str | output_path | Path where the merged weights are written |
```python
import fire
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from peft import PeftModel


def lora(base_model_path: str, lora_path: str, output_path: str):
    print(f'base_model_path: {base_model_path}')
    print(f'lora_path: {lora_path}')
    print(f'output_path: {output_path}')

    # Load the base model in float16, placing layers automatically across devices
    config = AutoConfig.from_pretrained(base_model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(base_model_path,
                                                 config=config,
                                                 device_map="auto",
                                                 torch_dtype=torch.float16,
                                                 trust_remote_code=True)
    # Attach the LoRA adapter on top of the base model
    model = PeftModel.from_pretrained(model, lora_path, device_map="auto")
    model = model.eval()

    # Fold the LoRA weights into the base weights and drop the adapter wrapper
    model = model.merge_and_unload()
    model.save_pretrained(output_path)

    # Save the tokenizer alongside the merged weights
    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
    tokenizer.save_pretrained(output_path)


if __name__ == '__main__':
    fire.Fire(lora)
```
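Conceptually, `merge_and_unload` folds each LoRA update back into the corresponding base weight: for an adapted linear layer, the merged weight is W' = W + (alpha/r)·B·A, where A and B are the low-rank adapter matrices. The sketch below illustrates this with plain torch tensors; the shapes, rank, and scaling factor are illustrative assumptions, not values taken from any particular model:

```python
import torch

# Illustrative dimensions (assumptions): 8x8 weight, LoRA rank 2, alpha 4
d_out, d_in, r, alpha = 8, 8, 2, 4

base_w = torch.randn(d_out, d_in)   # frozen base weight W
lora_a = torch.randn(r, d_in)       # LoRA down-projection A
lora_b = torch.randn(d_out, r)      # LoRA up-projection B
scaling = alpha / r

# The merge step: fold the scaled low-rank update into the base weight
merged_w = base_w + scaling * (lora_b @ lora_a)

# The merged weight reproduces "base output + adapter output" exactly
x = torch.randn(d_in)
out_adapter = base_w @ x + scaling * (lora_b @ (lora_a @ x))
out_merged = merged_w @ x
print(torch.allclose(out_adapter, out_merged, atol=1e-5))
```

After merging, the adapter matrices are no longer needed at inference time, which is why the tool writes a single set of standard Hugging Face weights to `output_path`.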