Instruction tuning中loss怎么算的
# 训练函数
def train_red_model(model, train_dataloader, epochs=3, lr=6e-3):
optimizer = AdamW(model.parameters(), lr=lr)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
for epoch in range(epochs):
model.train()
total_loss = 0
for batch in train_dataloader:
input_ids = batch["input_ids"].squeeze(1).to(device) # 移除多余维度
attention_mask = batch["attention_mask"].squeeze(1).to(device)
optimizer.zero_grad()
outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids) # 自回归任务
loss = outputs.loss
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(train_dataloader)
print(f"Epoch {epoch + 1}/{epochs}, Average Loss: {avg_loss:.4f}")这种instruction tuning的loss是怎么算的,详细解释一下
在代码中,train_red_model
函数使用了一个自回归语言建模任务的损失计算方式,具体体现在 outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
和 loss = outputs.loss
这两行。这是一种常见的 instruction tuning 损失计算方法,适用于像Qwen2.5-3B这样的自回归语言模型(causal language model)。下面将详细解释这种损失的计算原理、背后的逻辑,以及它在instruction tuning中的意义。
1. 自回归语言建模的基本概念
自回归语言模型(例如Qwen2.5-3B、GPT系列)是通过预测序列中下一个token来训练的。给定一个输入序列 (
x
=
[
x
1
,
x
2
,
…
,
x
n
]
x = [x_1, x_2, \dots, x_n]
x=[x1,x2,…,xn] ),模型的目标是学习条件概率分布:
P
(
x
t
∣
x
1
,
x
2
,
…
,
x
t
−
1
)
P(x_t | x_1, x_2, \dots, x_{t-1})
P(xt∣x1,x2,…,xt−1)
其中 (
x
t
x_t
xt ) 是第 (
t
t
t ) 个token,基于前面的上下文 (
x
1
,
x
2
,
…
,
x
t
−
1
x_1, x_2, \dots, x_{t-1}
x1,x2,…,xt−1 ) 进行预测。
训练时,模型通过最大化整个序列的对数似然来优化参数,即最小化负对数似然损失(Negative Log-Likelihood, NLL):
L
=
−
1
n
∑
t
=
1
n
log
P
(
x
t
∣
x
1
,
x
2
,
…
,
x
t
−
1
)
\mathcal{L} = -\frac{1}{n} \sum_{t=1}^{n} \log P(x_t | x_1, x_2, \dots, x_{t-1})
L=−n1t=1∑nlogP(xt∣x1,x2,…,xt−1)
2. Instruction Tuning中的数据与任务
在instruction tuning中,数据通常是对话或指令-回复对的形式。例如,你的 allenai/tulu-3-sft-mixture
数据集中,一个样本可能如下:
- 输入:
"Instruction: Write a function to find palindromes\nResponse: def find_palindrome_titles..."
为了适配自回归语言模型,我们将输入和输出拼接为一个完整序列,并要求模型预测整个序列的token。这种方式假设模型不仅要理解指令,还要生成对应的回复。训练时,输入和标签是同一个序列(labels=input_ids
),但通过掩码机制确保模型只预测未来的token。
3. 代码中的损失计算细节
在代码中:
outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
loss = outputs.loss
model
是基于 AutoModelForCausalLM
的Qwen2.5-3B模型,调用其 forward
方法时,labels=input_ids
表示将输入序列同时作为预测目标。下面是损失计算的逐步拆解:
输入与标签
input_ids
:形状为[batch_size, seq_len]
,表示token化的输入序列。例如:[Instruction, :, Write, a, function, ..., Response, :, def, find, ...]
attention_mask
:形状为[batch_size, seq_len]
,标记有效token(1)和填充token(0)。labels
:设置为input_ids
,即模型的目标是重建整个输入序列。
自回归掩码
自回归模型使用因果掩码(causal mask),确保在预测第 ( t t t ) 个token时,只依赖前 ( t − 1 t-1 t−1 ) 个token。这种掩码由Transformer的注意力机制自动实现:
- 对于位置 ( t t t ),注意力只能看到 ( [ 1 , 2 , … , t − 1 ] [1, 2, \dots, t-1] [1,2,…,t−1] ) 的token,后续token被屏蔽。
- 例如,给定序列
[A, B, C]
:- 预测
A
时无输入(仅初始化)。 - 预测
B
时输入A
。 - 预测
C
时输入A, B
。
- 预测
前向传播与Logits
模型的 forward
方法返回一个对象(通常是 CausalLMOutputWithPast
),包含:
logits
:形状为[batch_size, seq_len, vocab_size]
,表示每个位置的token预测概率分布。loss
:如果提供了labels
,会自动计算损失。
对于每个位置 ( t t t ):
- 输入:( x 1 , x 2 , … , x t − 1 x_1, x_2, \dots, x_{t-1} x1,x2,…,xt−1 )
- 输出:( logits t ∈ R vocab_size \text{logits}_t \in \mathbb{R}^{\text{vocab\_size}} logitst∈Rvocab_size )
- 目标:(
x
t
x_t
xt )(来自
labels
)
交叉熵损失
outputs.loss
是由Hugging Face内置的交叉熵损失函数计算的:
L
=
−
1
N
∑
i
=
1
B
∑
t
=
1
S
1
mask
i
,
t
=
1
log
P
(
x
i
,
t
∣
x
i
,
1
,
…
,
x
i
,
t
−
1
)
\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{B} \sum_{t=1}^{S} \mathbb{1}_{\text{mask}_{i,t}=1} \log P(x_{i,t} | x_{i,1}, \dots, x_{i,t-1})
L=−N1i=1∑Bt=1∑S1maski,t=1logP(xi,t∣xi,1,…,xi,t−1)
其中:
- (
B
B
B ) 是批大小(
batch_size
)。 - (
S
S
S ) 是序列长度(
seq_len
)。 - (
1
mask
i
,
t
=
1
\mathbb{1}_{\text{mask}_{i,t}=1}
1maski,t=1 ) 表示只对有效token(由
attention_mask
指定)计算损失,忽略padding部分。 - (
P
(
x
i
,
t
∣
x
i
,
1
,
…
,
x
i
,
t
−
1
)
P(x_{i,t} | x_{i,1}, \dots, x_{i,t-1})
P(xi,t∣xi,1,…,xi,t−1) ) 是模型预测的概率,通过softmax从
logits
得到:
P ( x i , t = k ) = exp ( logits i , t , k ) ∑ j = 1 vocab_size exp ( logits i , t , j ) P(x_{i,t} = k) = \frac{\exp(\text{logits}_{i,t,k})}{\sum_{j=1}^{\text{vocab\_size}} \exp(\text{logits}_{i,t,j})} P(xi,t=k)=∑j=1vocab_sizeexp(logitsi,t,j)exp(logitsi,t,k)
其中 ( k k k ) 是目标token的索引。
损失的实际计算
- Shift操作:
input_ids
被用作输入,labels
被用作目标,但labels
会被左移一位(shifted),因为我们预测的是下一个token。- 例如,输入
[A, B, C]
,labels
会被处理为[B, C, <ignore>]
,而logits
只计算到倒数第二个位置。
- 忽略填充:
attention_mask
确保损失只计算有效token部分。例如,若序列中有padding([A, B, <pad>]
),则只对[A, B]
计算损失。
- 平均:
- 最终损失是对所有有效token的平均负对数似然。
4. Instruction Tuning中的意义
在instruction tuning中,这种损失计算方式有以下特点:
自监督性质
- 通过将
labels=input_ids
,模型被训练为“自重建”整个序列,包括指令和回复。这本质上是一个监督形式的自回归任务。 - 对于例子:
模型会逐步预测每个token,从Instruction: Write a function to find palindromes\nResponse: def find_palindrome_titles...
Instruction
到最后的...
,确保生成的回复符合指令。
条件生成
- 由于指令在前,回复在后,模型学习到 ( P ( Response ∣ Instruction ) P(\text{Response} | \text{Instruction}) P(Response∣Instruction) )。训练时,指令部分的token被用来条件化回复部分的生成。
- 例如,模型在看到
Instruction: Write a function
后,预测Response: def
,逐步生成完整代码。
5. 与传统分类任务的区别
相比分类任务(如RoBERTa在GLUE上的交叉熵损失),instruction tuning的损失有以下不同:
- 目标:分类任务预测单一类别(如正面/负面),而这里预测整个序列。
- 维度:分类损失基于
[batch_size, num_labels]
的logits,而自回归损失基于[batch_size, seq_len, vocab_size]
。 - 掩码:自回归任务使用因果掩码和注意力掩码,确保单向预测和忽略填充。
6. 代码中的具体实现细节
squeeze(1)
input_ids
和attention_mask
的初始形状可能是[batch_size, 1, seq_len]
(因DataLoader返回的张量格式),squeeze(1)
移除多余的维度,变为[batch_size, seq_len]
。
outputs.loss
- Hugging Face的
AutoModelForCausalLM
自动处理了shift、掩码和损失计算。你无需手动实现,只需确保labels
和input_ids
一致。
损失值含义
loss.item()
是标量,表示当前batch的平均负对数似然。越低表示模型预测越接近目标序列。
7. 可能的优化与调试
- 检查数据:确保
input_ids
和attention_mask
正确对齐,避免padding干扰损失。 - 调整学习率:6e-3可能偏高,若损失不收敛,可尝试1e-4或5e-5。
- 添加验证:定期计算困惑度(
exp(loss)
),评估模型在验证集上的生成能力。
总结
在 train_red_model
函数中,损失是基于自回归语言建模的交叉熵损失,计算模型预测下一个token的准确性。结合instruction tuning,模型通过这种方式学习生成与指令匹配的回复。这种方法在对话、代码生成等任务中非常强大,体现了指令微调的核心思想。
深入解析 Open-Instruct 中 Instruction Tuning 的 Loss 计算
介绍基于 open-instruct
框架的SFT脚本中instruction tuning的损失(loss)是如何计算的。会逐步拆解代码中的关键部分,解释其背后的原理,并结合上下文分析其在指令微调中的意义。这篇博客面向熟悉深度学习和自然语言处理的读者。
随着大语言模型(LLM)的广泛应用,指令微调(Instruction Tuning)成为提升模型任务适配能力的重要方法。在AllenAI开发的 open-instruct
框架中,针对自回归语言模型(如LLaMA、Qwen等)的监督微调(Supervised Fine-Tuning, SFT)实现了一个高效的训练流程。本文将聚焦于该框架中损失(loss)的计算逻辑,结合代码逐步分析其实现细节,并探讨其在instruction tuning中的作用。
数据准备与输入格式
在 open-instruct
的SFT脚本中,训练数据通常来源于对话格式的数据集(如 allenai/tulu-3-sft-mixture
),每个样本包含一个 messages
字段,表示多轮对话。例如:
{
"messages": [
{"role": "user", "content": "Write a function to find palindromes"},
{"role": "assistant", "content": "def find_palindrome_titles..."}
]
}
数据预处理由 encode_sft_example
函数完成,其核心是将对话转换为适合自回归训练的输入-标签对:
input_ids = tokenizer.apply_chat_template(
conversation=messages,
tokenize=True,
return_tensors="pt",
padding=False,
truncation=True,
max_length=max_seq_length,
add_generation_prompt=False,
)
labels = input_ids.clone()
input_ids
:通过apply_chat_template
将对话序列化为token ID序列。例如,结果可能是[Instruction, :, Write, ..., Response, :, def, ...]
。labels
:初始时与input_ids
相同,但后续会通过掩码调整。
关键点在于,labels
中非助手(assistant)部分的token会被设置为 -100
,以排除其对损失的贡献:
for message_idx, message in enumerate(messages):
if message["role"] != "assistant":
# 计算非助手部分的起始和结束索引
message_start_idx = ... # 前文token数
message_end_idx = ... # 当前消息结束位置
labels[:, message_start_idx:message_end_idx] = -100
这种掩码机制确保模型只对助手的回复部分计算损失,而指令部分仅作为上下文,不参与梯度更新。
训练循环中的损失计算
训练循环位于 main
函数中,核心逻辑如下:
for epoch in range(starting_epoch, args.num_train_epochs):
for step, batch in enumerate(active_dataloader):
with accelerator.accumulate(model):
outputs = model(**batch, use_cache=False)
if args.reduce_loss == "mean":
loss = outputs.loss
else: # "sum"
logits = outputs.logits
labels = batch["labels"]
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")
shift_logits = shift_logits.view(-1, embedding_size)
shift_labels = shift_labels.view(-1)
loss = loss_fct(shift_logits, shift_labels)
total_loss += loss.detach().float()
accelerator.backward(loss)
输入与输出
batch
:包含input_ids
、labels
和attention_mask
,形状为[batch_size, seq_len]
。outputs
:调用model(**batch)
返回一个CausalLMOutputWithPast
对象,包含:logits
:形状[batch_size, seq_len, vocab_size]
,表示每个位置的token预测概率。loss
:如果提供了labels
,则自动计算的损失。
两种损失计算方式
脚本支持两种损失缩减方式,由 args.reduce_loss
参数控制(默认 "mean"
):
-
"mean"
模式- 直接使用
outputs.loss
,这是Hugging Facetransformers
内置的自回归损失:
L mean = − 1 N ∑ i = 1 B ∑ t = 1 S 1 labels i , t ≠ − 100 log P ( x i , t ∣ x i , 1 , … , x i , t − 1 ) \mathcal{L}_{\text{mean}} = -\frac{1}{N} \sum_{i=1}^{B} \sum_{t=1}^{S} \mathbb{1}_{\text{labels}_{i,t} \neq -100} \log P(x_{i,t} | x_{i,1}, \dots, x_{i,t-1}) Lmean=−N1i=1∑Bt=1∑S1labelsi,t=−100logP(xi,t∣xi,1,…,xi,t−1) - 其中:
- ( B B B ):批大小。
- ( S S S ):序列长度。
- (
N
N
N ):有效token数(
labels != -100
的数量)。 - (
P
(
x
i
,
t
∣
.
.
.
)
P(x_{i,t} | ...)
P(xi,t∣...) ):通过
logits
的softmax计算。
- 特点:
- 内置损失自动处理了因果掩码(只预测下一个token)和标签掩码(忽略
-100
)。 - 对每个有效token的损失取平均,适合均衡样本权重。
- 内置损失自动处理了因果掩码(只预测下一个token)和标签掩码(忽略
- 直接使用
-
"sum"
模式- 手动计算交叉熵损失, summing 而非 averaging:
shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = torch.nn.CrossEntropyLoss(reduction="sum") loss = loss_fct(shift_logits.view(-1, embedding_size), shift_labels.view(-1))
- 步骤:
- Shift操作:
logits
去掉最后一位(无对应标签),labels
去掉第一位(无前文预测),对齐预测与目标。 - 展平:将张量变为
[batch_size * (seq_len-1), vocab_size]
和[batch_size * (seq_len-1)]
。 - 交叉熵:
L sum = − ∑ i = 1 B ∑ t = 1 S − 1 1 labels i , t ≠ − 100 log P ( x i , t ∣ x i , 1 , … , x i , t − 1 ) \mathcal{L}_{\text{sum}} = -\sum_{i=1}^{B} \sum_{t=1}^{S-1} \mathbb{1}_{\text{labels}_{i,t} \neq -100} \log P(x_{i,t} | x_{i,1}, \dots, x_{i,t-1}) Lsum=−i=1∑Bt=1∑S−11labelsi,t=−100logP(xi,t∣xi,1,…,xi,t−1)
- Shift操作:
- 特点:
- 对所有有效token的损失求和,而不是平均。
- 在高梯度累积(gradient accumulation)场景下,能更平等地对待每个token,避免样本长度的影响(详见后文讨论)。
- 手动计算交叉熵损失, summing 而非 averaging:
梯度计算
accelerator.backward(loss)
:根据损失计算梯度,结合accumulate
实现梯度累积。- 仅对
labels != -100
的token(如助手回复)产生梯度贡献。
损失的数学原理
自回归语言模型的目标是最大化序列的对数似然:
L
=
−
∑
t
=
1
S
log
P
(
x
t
∣
x
1
,
…
,
x
t
−
1
)
\mathcal{L} = -\sum_{t=1}^{S} \log P(x_t | x_1, \dots, x_{t-1})
L=−t=1∑SlogP(xt∣x1,…,xt−1)
在instruction tuning中:
- 输入序列包含指令和回复。
labels
中的掩码(-100
)限制损失仅计算回复部分。- 例如,对于序列
[Instruction, :, Write, Response, :, def]
:input_ids
:[1, 2, 3, 4, 5, 6]
labels
:[-100, -100, -100, 4, 5, 6]
- 损失仅基于
[4, 5, 6]
的预测。
"mean"
vs "sum"
的差异:
"mean"
:归一化到有效token数,适合样本长度差异较大时。"sum"
:保留总损失规模,可能在长序列中放大影响,论文中提到这在对话任务(如AlpacaEval)中可提升5个百分点。
Instruction Tuning 中的意义
掩码的作用
通过将非助手部分的 labels
设置为 -100
,模型专注于学习生成回复,而非重复指令。这种方式模拟了对话场景下的条件生成:
P
(
Response
∣
Instruction
)
P(\text{Response} | \text{Instruction})
P(Response∣Instruction)
- 指令作为上下文,提供生成的基础。
- 损失仅优化回复的生成质量。
"sum"
模式的优势
代码注释提到,使用 "sum"
可改进聊天模型性能(见 https://github.com/huggingface/transformers/issues/24725
)。原因在于:
- 在高梯度累积下,
"mean"
会平均每个样本的损失,导致短样本的影响被放大。 "sum"
按token数加权,确保长序列的训练信号不被稀释。
代码中的关键实现细节
-
数据加载:
DataCollatorForSeq2Seq
动态填充批次,确保input_ids
和labels
对齐。attention_mask
标记有效token,防止padding干扰。
-
Shift操作:
- 自回归模型预测下一个token,因此
logits
和labels
需要错位对齐。
- 自回归模型预测下一个token,因此
-
加速器支持:
accelerator
管理分布式训练和梯度累积,total_loss
在主进程上聚合。
-
日志与监控:
- 每
logging_steps
步记录平均损失和token处理速度(TPS)。
- 每
调试与优化建议
- 检查掩码:打印
labels
和input_ids
,确保非助手部分正确置为-100
。 - 损失异常:若损失过高,可能是学习率(默认
2e-5
)需调整,或数据预处理有误。 - 对比实验:尝试
"mean"
和"sum"
,观察在你的数据集(如tulu-3-sft-mixture
)上的性能差异。
结语
open-instruct
的SFT脚本通过自回归损失实现了instruction tuning的核心目标:让模型根据指令生成高质量回复。其损失计算结合了掩码机制(聚焦助手部分)和灵活的缩减方式("mean"
或 "sum"
),既高效又适应多种场景。理解这些细节不仅有助于复现训练流程,还能为自定义微调策略提供启发。未来,你可以尝试调整掩码逻辑或损失权重,进一步优化模型在特定任务上的表现。
logits和labels移位解释
在自回归语言模型的训练中,logits
去掉最后一位和 labels
去掉第一位的“shift操作”是损失计算的核心步骤,用来对齐模型的预测与目标。这里的操作看似简单,但背后有深刻的逻辑,与自回归任务的预测机制紧密相关。下面我将详细解释为什么 logits
需要去掉最后一位,以及这一操作的目的和意义。
自回归语言建模的基本原理
自回归语言模型(如Qwen2.5-3B、LLaMA等)通过逐步预测序列中的下一个token来生成文本。给定一个输入序列 (
x
=
[
x
1
,
x
2
,
…
,
x
n
]
x = [x_1, x_2, \dots, x_n]
x=[x1,x2,…,xn] ),模型的目标是学习条件概率分布:
P
(
x
t
∣
x
1
,
x
2
,
…
,
x
t
−
1
)
P(x_t | x_1, x_2, \dots, x_{t-1})
P(xt∣x1,x2,…,xt−1)
其中:
- ( x t x_t xt ) 是第 ( t t t ) 个token。
- ( x 1 , x 2 , … , x t − 1 x_1, x_2, \dots, x_{t-1} x1,x2,…,xt−1 ) 是前文上下文。
在训练时,模型一次性处理整个序列,输出每个位置的预测概率(logits
),然后与目标序列(labels
)比较,计算损失。但由于自回归的因果性质,模型在每个位置 (
t
t
t ) 的预测是基于前 (
t
−
1
t-1
t−1 ) 个token的,因此需要对齐输入和输出的维度。
输入与输出的维度
假设输入序列为 [A, B, C]
(长度为3):
input_ids
:[A, B, C]
,形状[batch_size, seq_len]
,这里seq_len = 3
。labels
:通常与input_ids
相同,[A, B, C]
,因为自回归任务的目标是重建输入序列。
模型处理后:
logits
:形状[batch_size, seq_len, vocab_size]
,即[3, vocab_size]
,为每个位置生成一个词汇表大小的预测分布。
具体来说:
- 位置 0(输入
[A]
):预测下一个token(应为B
),输出logits_0
。 - 位置 1(输入
[A, B]
):预测下一个token(应为C
),输出logits_1
。 - 位置 2(输入
[A, B, C]
):预测下一个token(应为序列外的未知token),输出logits_2
。
为什么需要 Shift 操作?
问题在于:logits
的最后一个位置没有对应的真实目标。让我们逐步分析:
-
logits
的含义:logits
表示模型在每个位置对下一个token的预测。- 对于长度为 (
n
n
n ) 的序列,
logits
有 ( n n n ) 个位置,分别对应:logits[0]
:预测 ( x 2 x_2 x2 )(基于 ( x 1 x_1 x1 ))。logits[1]
:预测 ( x 3 x_3 x3 )(基于 ( x 1 , x 2 x_1, x_2 x1,x2 ))。- …
logits[n-1]
:预测 ( x n x_n xn )(基于 ( x 1 , … , x n − 1 x_1, \dots, x_{n-1} x1,…,xn−1 ))。logits[n]
:预测序列外的下一个token(基于 ( x 1 , … , x n x_1, \dots, x_n x1,…,xn ))。
-
labels
的含义:labels
是我们希望模型预测的目标序列,通常与input_ids
相同。- 对于
[A, B, C]
,labels = [A, B, C]
,但在自回归任务中,我们关心的是“下一个token”:- 基于
[A]
预测B
。 - 基于
[A, B]
预测C
。 - 基于
[A, B, C]
预测什么?——没有后续token可用。
- 基于
-
维度不对齐:
logits
有 ( n ) 个位置([logits_0, logits_1, logits_2]
)。- 但有效目标只有 ( n-1 ) 个(
[B, C]
),因为最后一个logits_2
预测的是序列外的token,而训练数据中没有提供这个目标。
Shift 操作的具体作用
为了解决上述问题,shift_logits
和 shift_labels
通过“错位”对齐预测和目标:
代码中的实现
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
-
shift_logits = logits[..., :-1, :]
:- 去掉
logits
的最后一位。 - 原
logits
:[logits_0, logits_1, logits_2]
(形状[batch_size, 3, vocab_size]
)。 - 后
shift_logits
:[logits_0, logits_1]
(形状[batch_size, 2, vocab_size]
)。 - 原因:
logits_2
是基于[A, B, C]
预测的下一个token,但训练数据中没有对应的真实目标(序列到C
结束),所以丢弃它。
- 去掉
-
shift_labels = labels[..., 1:]
:- 去掉
labels
的第一位。 - 原
labels
:[A, B, C]
(形状[batch_size, 3]
)。 - 后
shift_labels
:[B, C]
(形状[batch_size, 2]
)。 - 原因:
labels
的第一位A
没有前文可预测(模型从空序列开始无法直接生成A
),而我们关心的是基于前文预测的后续token。
- 去掉
对齐后的结果
shift_logits
:[logits_0, logits_1]
。logits_0
:预测B
。logits_1
:预测C
。
shift_labels
:[B, C]
。- 目标
B
:与logits_0
对齐。 - 目标
C
:与logits_1
对齐。
- 目标
现在,shift_logits
和 shift_labels
的长度一致(均为 (
n
−
1
n-1
n−1 )),可以直接用于交叉熵损失计算:
L
=
−
∑
t
=
1
n
−
1
log
P
(
x
t
+
1
∣
x
1
,
…
,
x
t
)
\mathcal{L} = -\sum_{t=1}^{n-1} \log P(x_{t+1} | x_1, \dots, x_t)
L=−t=1∑n−1logP(xt+1∣x1,…,xt)
图解说明
以 [A, B, C]
为例:
输入序列: [A, B, C]
logits: [logits_0, logits_1, logits_2]
|预测B |预测C |预测下一个未知token
labels: [A, B, C]
shift_logits: [logits_0, logits_1]
shift_labels: [B, C]
logits_0
(基于[A]
)预测B
,与shift_labels[0]
(B
)比较。logits_1
(基于[A, B]
)预测C
,与shift_labels[1]
(C
)比较。logits_2
(基于[A, B, C]
)无对应目标,丢弃。
为什么 logits
去掉最后一位是必要的?
-
无目标问题:
logits[n-1]
(即logits_2
)预测的是序列外的下一个token,但训练数据只提供到C
,没有后续token的真实值。- 如果保留
logits_2
,无法为其分配一个合理的label
,会导致损失计算出错或引入噪声。
-
因果性:
- 自回归模型的注意力机制使用因果掩码(causal mask),确保位置 ( t t t ) 只依赖 ( [ 1 , … , t − 1 ] [1, \dots, t-1] [1,…,t−1] )。
- 但
logits
的最后一个位置仍然会生成预测,只是没有意义(因为没有监督信号)。
-
损失计算需求:
- 交叉熵损失要求预测(
logits
)和目标(labels
)的维度匹配。 - 去掉
logits
最后一位后,shift_logits
和shift_labels
的长度均为 ( n − 1 n-1 n−1 ),完美对齐。
- 交叉熵损失要求预测(
在 Instruction Tuning 中的意义
在instruction tuning中,序列通常包含指令和回复,例如:
[Instruction, :, Write, Response, :, def, find]
-
input_ids
:[1, 2, 3, 4, 5, 6, 7]
。 -
labels
(掩码后):[-100, -100, -100, 4, 5, 6, 7]
(只训练回复部分)。 -
logits
:[l_0, l_1, l_2, l_3, l_4, l_5, l_6]
。 -
shift_logits
:[l_0, l_1, l_2, l_3, l_4, l_5]
。 -
shift_labels
:[-100, -100, -100, 4, 5, 6]
。 -
l_3
预测4
(Response
),l_4
预测5
(:
),依此类推。 -
最后一位
l_6
(基于[1, 2, 3, 4, 5, 6, 7]
)预测序列外的token,无对应目标,故丢弃。
掩码(-100
)进一步确保损失只计算回复部分的预测,而 logits
去掉最后一位则是为了保持维度一致性。
总结
logits
去掉最后一位是为了剔除没有真实目标的预测,使其与 shift_labels
对齐。这种“shift操作”是自回归语言建模的标准做法,确保损失计算只基于有效的前 ( n-1 ) 个预测。结合instruction tuning的掩码机制,这一操作让模型专注于生成回复,同时保持计算的正确性和效率。理解这一点有助于深入掌握自回归训练的细节,也为调试和优化提供了基础。
后记
2025年3月25日14点15分于上海,在grok 3大模型辅助下完成。