Unsloth使用简介

简而言之用其降低模型训练时的显存。

Unsloth与HuggingFace生态兼容,可以很容易将其与transformers、peft、trl等代码库进行结合,以实现模型的SFT与DPO,仅需修改模型的加载方式即可,无需对此前的训练代码进行过多的修改。
安装方法如下:

pip install git+https://github.com/unslothai/unsloth.git

SFT

支持的模型 为Llama (Yi, TinyLlama, Qwen, Deepseek etc) 和Mistral 、Qwen架构。所以基本能都包括了。

官方示例如下:

import torch
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel

max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number

# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/mistral-7b",
    max_seq_length=max_seq_length,
    dtype=None,  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
    load_in_4bit=True,  # Use 4bit quantization to reduce memory usage. Can be False
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

# 如果想进行lora训练则下面可选
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,  # Dropout = 0 is currently optimized
    bias="none",  # Bias = "none" is currently optimized
    use_gradient_checkpointing=True,
    random_state=3407,
)

args = SFTConfig(
    output_dir="./output",
    max_seq_length=max_seq_length,
    dataset_text_field="text",
)

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
)
trainer.train()

其与huggingface集成的相当好,使用huggingface的训练脚本其他基本都不用变,只需要将模型加载时的AutoModelForCausalLM 换为FastLanguageModel。如果想peft lora的话进一步使用FastLanguageModel.get_peft_model即可。Qlora在此基础上将load_in_4bit==True设置即可。

DPO

与上述基本一样。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值