0. 引言
在使用 Hugging Face 的 transformers
库进行模型训练时,如果你希望忽略某些特殊标签/token的损失计算,可以通过在计算损失时屏蔽特定 token 的贡献来实现的。下面介绍一些方法,仅供参考。
1. 使用 ignore_index
选项
在 transformers
库中,损失计算通常是通过 CrossEntropyLoss
完成的。CrossEntropyLoss
有一个 ignore_index
参数,允许你指定某些 token 的损失不被计算。你可以将特殊标签的索引设置为 ignore_index
,从而忽略这些 token 的损失计算。
示例代码:
假设你有一个特殊的标签 token ID 为 SPECIAL_TOKEN_ID
,需要忽略 loss 计算。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# 假设你已经加载了模型和tokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# 定义损失函数,忽略SPECIAL_TOKEN_ID的损失
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=SPECIAL_TOKEN_ID)
# 假设 input_ids 和 labels 是你的输入和目标输出
outputs = model(input_ids=input_ids, labels=labels)
logits = outputs.logits
# 计算损失,但忽略标签为 SPECIAL_TOKEN_ID 的token
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
# 反向传播
loss.backward()
在这个例子中,CrossEntropyLoss
的 ignore_index
被设置为 SPECIAL_TOKEN_ID
,因此当 labels
中的某个位置包含 SPECIAL_TOKEN_ID
时,该位置的损失将不会被计算。
2. 手动设置标签
在某些情况下,你可能需要手动设置需要忽略的标签。将需要忽略的标签位置设置为 ignore_index
对应的值(通常是 -100
,这是 Hugging Face 默认的 ignore_index
值)。
示例代码:
# 特殊标签 ID
SPECIAL_TOKEN_ID = 12345
# 将 labels 中的 SPECIAL_TOKEN_ID 替换为 -100
labels[labels == SPECIAL_TOKEN_ID] = -100
outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss
# 反向传播
loss.backward()
在这个代码中,我们手动将 labels
中所有等于 SPECIAL_TOKEN_ID
的位置替换为 -100
。在 Hugging Face 的 transformers
库中,CrossEntropyLoss
默认会忽略 -100
对应的损失计算。
3. 调整模型的 forward 方法
如果你希望更灵活地控制损失计算过程,可以自定义模型的 forward
方法或在训练循环中手动计算损失,并使用前述的 ignore_index
策略来忽略特定标签的损失。
示例代码:
class CustomModel(AutoModelForCausalLM):
def forward(self, input_ids, labels=None, **kwargs):
outputs = super().forward(input_ids, labels=labels, **kwargs)
logits = outputs.logits
if labels is not None:
# 手动计算损失,忽略特定标签
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
outputs = (loss,) + outputs[1:]
return outputs
在这个自定义的模型中,forward
方法手动计算损失,并使用 ignore_index=-100
忽略特定标签的损失计算。
4. 总结
通过在 Hugging Face 的 transformers
库中使用 ignore_index
或者手动设置 labels
的方式,可以有效忽略特殊标签的损失计算,从而确保这些标签不会影响模型的训练。
欢迎关注本人,我是喜欢搞事的程序猿; 一起进步,一起学习;
欢迎关注知乎/CSDN:SmallerFL
也欢迎关注我的wx公众号(精选高质量文章):一个比特定乾坤