从零实现强化学习DPO(SimPO)训练代码

强化学习已经越来越被大家所熟知了,从最开始的PPO到现在的各种DPO及其相关变体。对于SFT,理论上讲是让模型的输出更符合规范,而对于强化学习来说,应该就是让模型知道什么是不可以输出的。
现在我们从零开始,进行一下强化学习DPO的全过程代码实现。

1、加载模型。

这一步就是需要再huggingface加载成熟的模型了,后续使用过程中会用到输入给模型tensor,模型输出logits的过程,下面我们开始加载模型并简单的示例一下输出logits的过程。我们以B站的1.9B小模型为例开始:
了解过DPO的原理会知道DPO训练时会有两个模型,分别是policy model和ref model,前者就是我们要进行训练的模型,而后者就是未训练前的模型,用以规范模型的输出。

from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
# 1、加载模型与tokenizer
model_path = 'IndexTeam/Index-1___9B-Chat'
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16).to('cuda:3')
ref_model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16).to('cuda:3')
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

在DPO的过程中我们需要将输入给模型,使其输出logits,示例如下:
我们随便输入一个2批次的tensor给模型:

data = torch.tensor([[1,200,33,333],[1,2,3,4]]).to("cuda:3")
out = model(data)
print(out.logits.shape)

其输出为一个CausalLMOutputWithPast特有类型,其中的keys包括[‘logits’, ‘past_key_values’],我们需要提取出logits即可,上述代码的输出如下,对应的为batch、sequence_len、vocab_size:

torch.Size([2, 4, 65029])

2、处理数据

我们知道DPO的数据正常会有三个字段,如下:

  • prompt
  • chosen
  • rejected

我们本次使用的数据类似如下,jsonl格式,每一行都有对应的prompt、chosen、rejected:

{"prompt":"What is the primary goal of Unsloth for LLM fine-tuning?","chosen":"The primary goal of Unsloth for LLM fine-tuning is to accelerate the process, achieving a 2x speedup while maintaining 0% accuracy degradation compared to normal QLoRA.","rejected":"The primary goal of Unsloth for LLM fine-tuning is to slow down the process and increase memory usage."}

其中输入给模型的分别是prompt+chosen作为一个chosen的完整字段,prompt+rejected作为一个rejected的完整字段。下面我们开始进行数据的处理,首先进行一下我们dataset的编写:

class RlhfDataset(Dataset):
    def __init__(self, file_path, tokenizer):
        with open(file_path, "r", encoding="utf-8") as file:
            data_list = file.readlines()
        self.data_list = data_list
        self.tokenizer = tokenizer

    def __getitem__(self, item):
        data = self.data_list[item]
        data = json.loads(data)
        prompt = data['prompt']
        chosen = data['chosen']
        rejected = data['rejected']

        chosen_full_text = f"{prompt}\n\n### Response:\n{chosen}"
        rejected_full_text = f"{prompt}\n\n### Response:\n{rejected}"

        prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
        chosen_full_tokens = self.tokenizer.encode(chosen_full_text, add_special_tokens=False)
        rejected_full_tokens = self.tokenizer.encode(rejected_full_text, add_special_tokens=False)

        input = {
                "prompt": prompt_tokens,
                "chosen": chosen_full_tokens,
                "rejected": rejected_full_tokens,
            }
        return input

    def __len__(self):
        return len(self.data_list)

这里我就直接内置了一个chat template。

然后再进行collate的编写:

def data_collate(batch, pad_token_id, device, max_length=None, if_mask_prompt=True):
    batch_data = {
        "prompt": [],
        "chosen": [],
        "rejected": [],
        "rejected_mask": [],
        "chosen_mask": []
    }

    # 判断长度及padding
    max_length_common = 0
    for key in ["chosen", "rejected"]:
        current_max = max(len(item[key]) for item in batch)
        max_length_common = max(max_length_common, current_max)

    # 转为torch tensor并padding,决定是否对prompt进行mask
    for item in batch:
        prompt = torch.tensor(item['prompt'])
        batch_data['prompt'].append(prompt)

        for key in ["chosen", "rejected"]:
            out = item[key]
            out_padding = out + [pad_token_id] * (max_length_common - len(out))
            mask = torch.ones(len(out_padding)).bool()

            # padding部分的mask设置为 IGNORE_INDEX
            mask[len(out):] = IGNORE_INDEX

            if if_mask_prompt:
                mask[:prompt.shape[0] + 2] = IGNORE_INDEX
            batch_data[key].append(torch.tensor(out_padding))
            batch_data[f"{key}_mask"].append(mask)

    # 进行最大长度截断
    for key in ["chosen", "rejected", "chosen_mask", "rejected_mask"]:
        tensor_stack = torch.stack(batch_data[key])
        if max_length is not None:
            tensor_stack = tensor_stack[:, :max_length]
        # 将tensor移到对应的device
        batch_data[key] = tensor_stack.to(device)
    return batch_data

最后我们就可以划分数据了:

# 加载数据
data_file = './unsloth_dpo.jsonl'
dataset = RlhfDataset(data_file, tokenizer)
# 划分训练集验证集
train_size = int(len(dataset) * 0.85)  # 85% for training
val_size = len(dataset) - train_size  # Remaining for validation
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# 设置相关参数
batch_size = 4
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=True,
    drop_last=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    collate_fn=customized_collate_fn,
    shuffle=False,
    drop_last=False
)

3、开始编写DPO、SimPO的损失函数

损失函数基本属于最重要的一部分了。

3.1 DPO loss、SimPo loss计算函数编写

这里我们假设已经知道模型输出的Log probabilities,那么两个loss的计算就可以如下表示。
关于DPO loss公式可见下图:
在这里插入图片描述
而对于SimPO,其实是很简单的变体,相较于DPO加了一个gamma参数,且去掉了ref模型制约,并且因为公式中有|y|,也就是说Log probabilities要求平均操作了。公式如下
在这里插入图片描述
最后计算loss时的Log probabilities有两种方式,一种是对Log probabilities在每个单词上求平均(SimPO中),另一中就是对每个句子中单词的Log probabilities求和(DPO)。
TRL中的DPO是进行Log probabilities的求和,但我们这里就改一下进行求平均。这里我们先写loss的代码,两种方式的Log probabilities代码我们下一节实现。

class DPOLoss(nn.Module):
    """
    DPO Loss
    """
    def __init__(self, beta: float = 0.1) -> None:
        super().__init__()
        self.beta = beta

    def forward(
            self,
            policy_chosen_logps: torch.Tensor,
            policy_rejected_logps: torch.Tensor,
            reference_chosen_logps: torch.Tensor,
            reference_rejected_logps: torch.Tensor,
    ):
        """
        policy_chosen_logps: 模型输出的对数概率。Shape: (batch_size,)
        policy_rejected_logps:   Shape: (batch_size,)
        reference_chosen_logps: Shape: (batch_size,)
        reference_rejected_logps: Shape: (batch_size,)
        """
        policy_logps = policy_chosen_logps - policy_rejected_logps
        reference_logps = reference_chosen_logps - reference_rejected_logps
        logits = policy_logps - reference_logps

        loss = -F.logsigmoid(self.beta * logits)

        # 下面两个用于追踪训练的进度
        chosen_rewards = (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = (policy_rejected_logps - reference_rejected_logps).detach()

        # 对每个batch进行平均
        return loss.mean(), chosen_rewards.mean(), rejected_rewards.mean()
class SimPo(nn.Module):
    """
    SimPO Loss
    """

    def __init__(self, beta: float = 0.1, gamma: float = 0.5) -> None:
        super().__init__()
        self.beta = beta
        self.gamma = gamma

    def forward(
            self,
            policy_chosen_logps: torch.Tensor,
            policy_rejected_logps: torch.Tensor,
    ):
        """
        policy_chosen_logps: 模型输出的对数概率。Shape: (batch_size,)
        policy_rejected_logps:   Shape: (batch_size,)
        """
        logits = policy_chosen_logps - policy_rejected_logps
        logits = logits - self.gamma
        loss = -F.logsigmoid(self.beta * logits)

        # 对每个batch进行平均(期望)
        return loss.mean()

3.2 Log probabilities计算

下面需要开始实现计算模型的Log probabilities。代码如下,这里的两个输入为logits, label,其中logits为label输入给模型后输出的结果,并且当前的logits是预测label中下一个词的,故需要进行位移操作:

def compute_logprobs(logits, labels, mask=None):
    """
    logits:  shape (batch_size, sequence_len, vocab_size),即将label输入给模型后输出的结果
    labels:  shape (batch_size, sequence_len)
    """

    # 需要先进行位移操作
    # 去掉标签的第一个
    labels = labels[:, 1:].clone()
    # 去掉模型输出的最后一个
    logits = logits[:, :-1, :]

    logps = F.log_softmax(logits, dim=-1)

    select_logprobs = torch.gather(
        input=logps,
        dim=-1,
        index=labels.unsqueeze(1)
    ).squeeze(1)

    if mask is not None:
        mask = mask[:, 1:].clone()
        # 进行掩码padding部分
        select_logprobs = select_logprobs * mask
        # 计算平均
        average_logprobs = select_logprobs.sum(-1) / mask.sum(-1)
        return average_logprobs
    else:
        return select_logprobs.mean(-1)

上面是已进行求平均的操作,即SimPO的实现,如果是TRL中DPO求和的操作话只需要将average_logprobs = select_logprobs.sum(-1) / mask.sum(-1)改为average_logprobs = select_logprobs.sum(-1)即可。

其实上面这个函数最终的输出取负数就是F.cross_entropy(logits, targets) 交叉熵的输出,只不过添加了mask操作而已,下面我们可以通过这两种不同的方式实现计算,下面是使用F.cross_entropy进行计算的代码:

def compute_logprobs_f_cross(logits, labels, mask=None):
    """
    logits:  shape (batch_size, sequence_len, vocab_size),即将label输入给模型后输出的结果
    labels:  shape (batch_size, sequence_len)
    """
    # 需要先进行位移操作
    # 去掉标签的第一个
    labels = labels[:, 1:].clone()
    # 去掉模型输出的最后一个
    logits = logits[:, :-1, :].clone()

    batch_size, sequence_len, vocab_size = logits.shape
    cross_entropy_loss = 0

    if mask is not None:
        mask = mask[:, 1:].clone()
        labels.masked_fill_(~mask, -100)
        for i in range(batch_size):
            cross_entropy_loss += F.cross_entropy(logits[i], labels[i])
    else:
        for i in range(batch_size):
            cross_entropy_loss += F.cross_entropy(logits[i], labels[i])
    cross_entropy_loss /= batch_size
    return cross_entropy_loss

最终我们进行一下测试即可得到如下结果:

logits = torch.tensor(
     [[2.0, 1.0, 0.1, 0.4],
      [0.5, 2.5, 0.3, 0.5],
      [0.6, 2.5, 0.3, 0.8],
      [0.5, 2.5, 0.6, 0.6]], dtype=torch.float32).unsqueeze(0)
 mask = torch.tensor([[True, True, False, False]])
 targets = torch.tensor([0, 1, 0, 2]).unsqueeze(0)
 loss1 = -compute_logprobs(logits, targets, mask)
 loss2 = compute_logprobs_f_cross(logits, targets, mask)
 print(loss1,loss2)
 ---------------------------
 tensor([1.5419]) tensor(1.5419)

要注意的是,F.cross_entropy中所计算的logits和target一般是不带batch的,例如Shape: (2, 3)与Shape: (2,),如下:

logits = torch.tensor(
    [[2.0, 1.0, 0.1],
     [0.5, 2.5, 0.3]])  # Shape: (2, 3)
targets = torch.tensor([0, 2])  # Shape: (2,)
  • logits:形状为 (2, 3) 的张量,表示两个样本的对数概率(logits)。每个样本有三个类别的对数概率。
  • targets:形状为 (2,) 的张量,表示每个样本的真实类别标签。第一个样本的真实类别是0,第二个样本的真实类别是2(大模型中这里真是类别就是其vocab_size)。

所以上述batch输入我们不能直接给进这个函数,需要每个batch进行计算。

3.2 最终batch计算loss

上面我们写好了loss的计算相关代码,下面只需要在batch层面使用上面写好的函数即可,代码如下:

def compute_batch_loss(batch, policy_model, reference_model, beta):
    # 决定使用哪个loss
    # loss_fn = SimPo(beta, 0.5)   SimPO loss
    loss_fn = DPOLoss(beta)   # DPO loss

    policy_chosen_logps = compute_logprobs(
        logits=policy_model(batch["chosen"]).logits,
        labels=batch["chosen"],
        mask=batch["chosen_mask"]
    )
    policy_rejected_logps = compute_logprobs(
        logits=policy_model(batch["rejected"]).logits,
        labels=batch["rejected"],
        mask=batch["rejected_mask"]
    )
    reference_chosen_logps = compute_logprobs(
        logits=reference_model(batch['chosen']).logits,
        labels=batch['chosen'],
        mask=batch["chosen_mask"]
    )
    reference_rejected_logps = compute_logprobs(
        logits=reference_model(batch['rejected']).logits,
        labels=batch['rejected'],
        mask=batch["rejected_mask"]
    )
    loss, chosen_rewards, rejected_rewards = loss_fn(
        policy_chosen_logps=policy_chosen_logps,
        policy_rejected_logps=policy_rejected_logps,
        reference_chosen_logps=reference_chosen_logps,
        reference_rejected_logps=reference_rejected_logps,
    )
    # SimPO使用如下
    # loss = loss_fn(
    #     policy_chosen_logps=policy_chosen_logps,
    #     policy_rejected_logps=policy_rejected_logps,
    # )
    # return loss
    return loss, chosen_rewards, rejected_rewards

4、开始训练

下面我们开始训练脚本的编写,不用Trainer确实比较麻烦,需要自己手动的epoch循环之类的,不过如果之前做过CV相关的话应该就不会陌生了。
下面是我们的训练函数

def train_model(
        policy_model, reference_model, train_loader, val_loader,
        optimizer, num_epochs, beta,
        eval_freq, eval_iter):
    tracking = {
        "train_losses": [],
        "train_chosen_rewards": [],
        "train_rejected_rewards": [],
        "val_losses": [],
        "val_chosen_rewards": [],
        "val_rejected_rewards": [],
        "tokens_seen": []
    }
    tokens_seen, global_step = 0, -1

    # 训练
    for epoch in range(num_epochs):
        # policy 模型需要训练
        policy_model.train()

        for idx, batch in enumerate(train_loader):
            optimizer.zero_grad()

            loss, chosen_rewards, rejected_rewards = compute_batch_loss(
                batch=batch,
                policy_model=policy_model,
                reference_model=reference_model,
                beta=beta
            )
            loss.backward()
            optimizer.step()

            global_step += 1
            tokens_seen += batch["chosen"].numel()

            # 验证
            if global_step % eval_freq == 0:
                res = evaluate_loss_dataloader(
                    policy_model=policy_model,
                    reference_model=reference_model,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    beta=beta,
                    eval_iter=eval_iter
                )
                tracking["train_losses"].append(res["train_loss"])
                tracking["train_chosen_rewards"].append(res["train_chosen_reward"])
                tracking["train_rejected_rewards"].append(res["train_rejected_reward"])
                tracking["val_losses"].append(res["val_loss"])
                tracking["val_chosen_rewards"].append(res["val_chosen_reward"])
                tracking["val_rejected_rewards"].append(res["val_rejected_reward"])
                tracking["tokens_seen"].append(tokens_seen)
                train_reward_margin = res["train_chosen_reward"] - res["train_rejected_reward"]
                val_reward_margin = res["val_chosen_reward"] - res["val_rejected_reward"]

                print(
                    f"Ep {epoch + 1} (Step {global_step:06d}): "
                    f"Train loss {res['train_loss']:.3f}, Val loss {res['val_loss']:.3f}, "
                    f"Train reward margins {train_reward_margin:.3f}, "
                    f"Val reward margins {val_reward_margin:.3f}"
                )

    return tracking

训练函数已经写好了,我们开始训练:

def main():
    torch.manual_seed(42)
    start_time = time.time()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

    num_epochs = 3
    tracking = train_model(
        policy_model=model,
        reference_model=ref_model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        num_epochs=num_epochs,
        beta=0.1,  # value between 0.1 and 0.5
        eval_freq=2,
        eval_iter=2
    )

    end_time = time.time()
    execution_time_minutes = (end_time - start_time) / 60
    print(f"Training completed in {execution_time_minutes:.2f} minutes.")

训练脚本中我就没有设置保存模型的代码了,大家可以自行设置或者在jupter中运行后接着进行推理测试。

5、总结:实验结果

5.1 实验

实验的话我们就只对DPO进行测试了。

使用B站的Index-1___9B-Chat模型在HF上随便找了一个关于unsloth的数据集进行了一下测试,数据集如下:
在这里插入图片描述
是一个关于unsloth的问答,jsonl格式,每行都有prompt、chosen、rejected字段。

训练过程结果如下:
在这里插入图片描述
loss下降还是可以的,测试了几个用例输出,相较之前回答会不一样,有一点点变好吧。因为我们只有50条数据,且DPO之类的强化学习主要是减少bad的输出,而不是学习新知识,故提升不大也在合理范围内。

5.2 总结

上述只是我们简单的进行了DPO SimPO的loss实现及训练代码编写,只是一个demo示例,并没有增加分布式训练、模型chat template适配等等。仅供学习原理使用吧。本文的全部代码已保存至github下:DPO_example

如果想要使用DPO或者Simpo、CPO等强化学习方法真正训练的话,
可以使用本项目中构建的强化学习框架,支持deepspeed的单机多卡Lora、Dora、Qlora、全量参数训练,并自动适配模型的chat template:RLHF训练

  • 29
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值