Qwen模型使用trl仓库的ktotrainer训练实战


前言

本文使用trl仓库的ktotrainer实现对Qwen模型的训练,使模型对齐人类偏好。之前在博客中(链接)已经大概讲解了KTO的算法思想,本文主要为实战部分。
本文首先使用trl的官方示例代码跑通整个流程,之后对数据集处理部分作了修改,修改后可以使用PPO、DPO相同的数据集进行kto训练。
本文的训练数据集为:comparison_gpt4_data_zh
模型为:Qwen-7B-Chat

文末有github仓库,仓库中还包含其他模型的kto训练脚本。


一、官方代码测试

trl官方kto.py样例脚本
官方脚本中使用了kto-mix-14k数据集对qwen1.5-1.8b-sft模型进行了测试。
其中kto-mix-14k数据集为kto算法适配的数据集,每条数据包括一个输入prompt、一个输出completion和一个输出是否可以被接受label。如下图所示:
在这里插入图片描述

1.kto脚本


from dataclasses import dataclass

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, setup_chat_format


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the KTO training script.
    """

    dataset_name: str = "/mnt/data3/xxxxxxxx/kto-mix-14k"


if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
    script_args, kto_args, model_args = parser.parse_args_into_dataclasses()

    # Load a pretrained model
    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    model_ref = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # If we are aligning a base model, we use ChatML as the default template
    if tokenizer.chat_template is None:
        model, tokenizer = setup_chat_format(model, tokenizer)

    # Load the dataset
    dataset = load_dataset(script_args.dataset_name)

    # Apply chat template
    def 
  • 12
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值