Unsloth助力Llama3-Chinese-8B-Instruct中文大模型微调实战

作者:c4t2024.08.14 13:50浏览量:10

简介:本文介绍了如何利用Unsloth工具对Llama3-Chinese-8B-Instruct中文开源大模型进行微调,通过实践经验和步骤详解,帮助读者提升模型在特定任务上的表现。

Unsloth助力Llama3-Chinese-8B-Instruct中文大模型微调实战

在当今自然语言处理(NLP)领域,大模型的微调已成为提升模型在特定任务上表现的重要手段。本文将详细介绍如何使用Unsloth这一开源大模型训练加速项目,对Llama3-Chinese-8B-Instruct中文开源大模型进行微调,以期为读者提供可操作的建议和解决问题的方法。

一、引言

Llama3-Chinese-8B-Instruct是基于Meta Llama-3的中文开源大模型,它在原版Llama-3的基础上,通过大规模中文数据的增量预训练和精选指令数据的精调,显著提升了中文基础语义和指令理解能力。而Unsloth则是一个专注于大模型训练加速的开源项目,能够显著提升训练速度并减少显存占用。

二、Unsloth与Llama3-Chinese-8B-Instruct的结合

1. Unsloth简介

Unsloth是一个开源的大模型训练加速项目,其特点包括:

  • 提升训练速度:最高可达5倍,Unsloth Pro更是高达30倍。
  • 减少显存占用:最大可减少80%的显存占用。
  • 广泛兼容性:与HuggingFace生态兼容,支持多种主流GPU设备。
2. Llama3-Chinese-8B-Instruct的优势

Llama3-Chinese-8B-Instruct模型在原版Llama-3的基础上进行了多项优化,包括:

  • 大规模训练数据:使用了约15万亿个标记(tokens)进行训练。
  • 中文增量预训练:针对中文语境进行了增量预训练。
  • 精选指令精调:使用精选指令数据进行精调,提升指令理解能力。

三、微调步骤详解

1. 环境设置

首先,需要创建一个Python虚拟环境并安装必要的依赖项。推荐使用Python 3.10版本,并安装Unsloth及相关库:

  1. conda create --name unsloth_env python=3.10
  2. conda activate unsloth_env
  3. !pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
  4. !pip install --no-deps "xformers<0.0.26" trl peft accelerate bitsandbytes
  5. !pip install modelscope
2. 下载预训练模型

使用modelscope工具下载Llama3-Chinese-8B-Instruct预训练模型:

  1. from modelscope import snapshot_download
  2. model_dir = snapshot_download('FlagAlpha/Llama3-Chinese-8B-Instruct', cache_dir="/root/models")
3. 加载模型与Tokenizer

利用Unsloth提供的FastLanguageModel加载模型与Tokenizer:

  1. from unsloth import FastLanguageModel
  2. import torch
  3. model, tokenizer = FastLanguageModel.from_pretrained(
  4. model_name="/root/models/Llama3-Chinese-8B-Instruct",
  5. max_seq_length=2048,
  6. dtype=torch.float16,
  7. load_in_4bit=True
  8. )
4. 设置LoRA训练参数

LoRA(Low-Rank Adaptation)是一种低阶适配器技术,可大幅减少模型微调时的参数更新量。设置LoRA参数如下:

  1. model = FastLanguageModel.get_peft_model(
  2. model,
  3. r=16,
  4. target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
  5. lora_alpha=16,
  6. lora_dropout=0,
  7. bias="none",
  8. use_gradient_checkpointing="unsloth",
  9. random_state=3407,
  10. use_rslora=False
  11. )