In short, it is used to reduce GPU memory consumption during model training.
Unsloth is compatible with the HuggingFace ecosystem and can easily be combined with libraries such as transformers, peft, and trl to run SFT and DPO. Only the way the model is loaded needs to change; the rest of an existing training script requires little modification.
It can be installed as follows:
pip install git+https://github.com/unslothai/unsloth.git
SFT
The supported model architectures are Llama (Yi, TinyLlama, Qwen, Deepseek, etc.), Mistral, and Qwen, so essentially all the common models are covered.
The official example is as follows:
import torch
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel

max_seq_length = 2048  # Supports automatic RoPE Scaling, so choose any number

# Example training data: any dataset with a "text" column works (see dataset_text_field below)
dataset = load_dataset("imdb", split="train")

# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/mistral-7b",
    max_seq_length=max_seq_length,
    dtype=None,  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
    load_in_4bit=True,  # Use 4bit quantization to reduce memory usage. Can be False
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
# Optional: add LoRA adapters if you want to do LoRA training
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,  # Dropout = 0 is currently optimized
    bias="none",  # Bias = "none" is currently optimized
    use_gradient_checkpointing=True,
    random_state=3407,
)
args = SFTConfig(
    output_dir="./output",
    max_seq_length=max_seq_length,
    dataset_text_field="text",
)

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
)
trainer.train()
Unsloth integrates very well with HuggingFace: an existing HuggingFace training script needs almost no changes, you only replace AutoModelForCausalLM with FastLanguageModel when loading the model. For PEFT LoRA training, additionally call FastLanguageModel.get_peft_model. For QLoRA, simply set load_in_4bit=True on top of that, as sketched below.
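A minimal sketch of that swap, reusing the model name from the example above (the rest of the training script stays unchanged):

# Before: standard HuggingFace loading
# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained("unsloth/mistral-7b")

# After: Unsloth loading; load_in_4bit=True turns the LoRA run into QLoRA
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/mistral-7b",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)
# Optional: attach LoRA adapters for PEFT training
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)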
DPO
The usage is basically the same as for SFT above: load the model with FastLanguageModel (optionally adding LoRA via get_peft_model), then pass it to trl's DPOTrainer instead of SFTTrainer.
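A rough sketch under that assumption; the exact DPOConfig/DPOTrainer argument names can differ slightly between trl versions, and the checkpoint and dataset here are just illustrative choices:

from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
from unsloth import FastLanguageModel

max_seq_length = 2048

# Preference data: pairs of chosen vs. rejected responses
dpo_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/zephyr-sft-bnb-4bit",  # an SFT checkpoint to align with DPO
    max_seq_length=max_seq_length,
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)

args = DPOConfig(
    output_dir="./dpo_output",
    beta=0.1,  # strength of the DPO preference loss
    max_length=max_seq_length,
)
trainer = DPOTrainer(
    model=model,
    ref_model=None,  # with LoRA adapters the frozen base model serves as the reference
    args=args,
    train_dataset=dpo_dataset,
    processing_class=tokenizer,  # older trl versions take tokenizer= instead
)
trainer.train()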