前言
本文使用trl仓库的ktotrainer实现对Qwen模型的训练,使模型对齐人类偏好。之前在博客中(链接)已经大概讲解了KTO的算法思想,本文主要为实战部分。
本文首先使用trl的官方示例代码跑通整个流程,之后对数据集处理部分作了修改,修改后可以使用PPO、DPO相同的数据集进行kto训练。
本文的训练数据集为:comparison_gpt4_data_zh
模型为:Qwen-7B-Chat
文末有github仓库,仓库中还包含其他模型的kto训练脚本。
一、官方代码测试
trl官方kto.py样例脚本
官方脚本中使用了kto-mix-14k数据集对qwen1.5-1.8b-sft模型进行了测试。
其中kto-mix-14k数据集为kto算法适配的数据集,每条数据包括一个输入prompt、一个输出completion和一个输出是否可以被接受label。如下图所示:
1.kto脚本
from dataclasses import dataclass
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, setup_chat_format
# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
The arguments for the KTO training script.
"""
dataset_name: str = "/mnt/data3/xxxxxxxx/kto-mix-14k"
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
script_args, kto_args, model_args = parser.parse_args_into_dataclasses()
# Load a pretrained model
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
model_ref = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# If we are aligning a base model, we use ChatML as the default template
if tokenizer.chat_template is None:
model, tokenizer = setup_chat_format(model, tokenizer)
# Load the dataset
dataset = load_dataset(script_args.dataset_name)
# Apply chat template
def