前言
在上一章【课程总结】day19(下):Transformer源码深入理解总结中,我们对Transformer架构以及初始化部分做了梳理,本章我们将对Transformer训练过程进行代码分析理解。
训练流程
- 训练过程主要由四个主要部分组成:
- 第一部分:加载数据集。通过get_dataloader()加载数据集,这一过程与Seq2Seq类似,这里不再赘述。
- 第二部分:数据对齐。调用collate_fn()函数,对数据集整理并对齐。
- 第三部分:开始训练。调用run_epoch()函数,循环遍历dataloader中的数据集并进行训练。
- 第四部分:前向传播。在训练过程中,调用EncoderDecoder的forward()函数进行前向传播。
代码分析理解
数据对齐 collate_fn()
源码如下:
def collate_fn(batch, tokenizer):
input_sentences, input_sentence_lens, output_sentences, output_sentence_lens = zip(
*batch
)
# 转索引【按本批量最大长度来填充】
input_sentence_len = max(input_sentence_lens)
input_idxes = []
for input_sentence in input_sentences:
input_idxes.append(tokenizer.encode_input(input_sentence, input_sentence_len))
# 转索引【按本批量最大长度来填充】
output_sentence_len = max(output_sentence_lens)
output_idxes = []
for output_sentence in output_sentences:
output_idxes.append(
tokenizer.encode_output(output_sentence, output_sentence_len)
)
# 转张量 [batch_size, seq_len] src
input_idxes = torch.LongTensor(input_idxes)
# src_mask [batch_size, 1, seq_len]
input_mask = (input_idxes != tokenizer.input_word2idx.get("<PAD>")).unsqueeze(-2)
# tgt [batch_size, seq_len]
output_idxes = torch.LongTensor(output_idxes)
# tgt [batch_size, seq_len - 1] 去掉最后一个
output_idxes_in = output_idxes[:, :-1]
# tgt_y [batch_size, seq_len - 1] 去掉开头 的 SOS
output_idxes_out = output_idxes[:, 1:]
# tgt_mask [batch_size, seq_len-1, seq_len-1]
output_mask = tokenizer.make_std_mask(output_idxes_in, tokenizer.output_word2idx.get("<PAD>"))
# 记录生成的有效字符
ntokens = (output_idxes_out != tokenizer.output_word2idx.get("<PAD>")).data.sum()
# src, src_mask, tgt, tgt_mask, tgt_y, ntokens
return input_idxes, input_mask, output_idxes_in, output_mask, output_idxes_out, ntokens
代码理解:
第一步:提取输入数据、输入数据长度、输出数据、输出数据长度
input_sentences, input_sentence_lens, output_sentences, output_sentence_lens = zip(
*batch
)
batch
:一个包含多个样本的列表;zip(*batch)
:将batch
中的每个样本解包,分别提取出输入句子、输入句子长度、输出句子和输出句子长度。
示例理解:
input_sentences: output_sentences:
['I', 'm', 'sick', '.'] ['我', '病' , '了', '。']
['I', 'm', 'tall', '.'] ['我', '个子' , '高', '。']
['Leave', 'me', '.'] ['让', '我', '一个人', '呆','会', '儿','。']
input_sentence_lens: output_sentence_lens:
[4, 4, 3] [4, 4, 7]
第二步:对数据进行填充
# 输入的最大长度为4,所以input_idxes填充为
['I', 'm', 'sick', '.']
['I', 'm', 'tall', '.']
['Leave', 'me', '.', '<PAD>']
# 输出的最大长度为7,所以output_idxes填充为
['<SOS>', '我', '病' , '了', '。', '<EOS>', '<PAD>', '<PAD>', '<PAD>']
['<SOS>', '我', '个子' , '高', '。', '<EOS>', '<PAD>', '<PAD>', '<PAD>']
['<SOS>', '让', '我', '一个人', '呆', '会', '儿', '。', '<EOS>']
第三步:生成input的mask
input_mask = (input_idxes != tokenizer.input_word2idx.get("<PAD>")).unsqueeze(-2)
- 输入位置不为
<PAD>
的位置,值为1,否则为0,从而形成mask。
第四步:生成错位的output
# tgt [batch_size, seq_len - 1] 去掉最后一个
output_idxes_in = output_idxes[:, :-1]
# 例如:['<SOS>', '让', '我', '一个人', '呆', '会', '儿', '。']
# tgt_y [batch_size, seq_len - 1] 去掉开头 的 SOS
output_idxes_out = output_idxes[:, 1:]
# 例如:['让', '我', '一个人', '呆', '会', '儿', '。', '<EOS>']
第五步:生成output的mask
output_mask = tokenizer.make_std_mask(output_idxes_in, tokenizer.output_word2idx.get("<PAD>"))
def make_std_mask(cls, tgt, pad):
"Create a mask to hide padding and future words."
tgt_mask = (tgt != pad).unsqueeze(-2)
tgt_mask = tgt_mask & Tokenizer.subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)
return tgt_mask
代码理解:
- 因为decoder的掩码多头注意力(mask MultiHeadAttention),既要屏蔽无效的PAD,同时还要屏蔽未来词。
- 所以tgt_mask是由
tgt_mask & Tokenizer.subsequent_mask
两部分按位与运算,即两者都为1才是有效的,如果有一个为0,则对应数据被遮挡。 Tokenizer.subsequent_mask
是生成一个三角矩阵,如下图:
开始训练 run_epoch()
def run_epoch(
data_iter,
model,
loss_compute,
optimizer,
scheduler,
mode="train",
accum_iter=1,
train_state=TrainState(),
device="cpu"
):
"""
Train a single epoch
"""
start = time.time()
total_tokens = 0
total_loss = 0
tokens = 0
n_accum = 0
for i, (src, src_mask, tgt, tgt_mask, tgt_y, ntokens) in enumerate(data_iter):
#
src = src.to(device=device)
tgt = tgt.to(device=device)
tgt_y = tgt_y.to(device=device)
# src = src.to(device=device)
out = model.forward(src, tgt, src_mask, tgt_mask)
loss, loss_node = loss_compute(out, tgt_y, ntokens)
# loss_node = loss_node / accum_iter
if mode == "train" or mode == "train+log":
loss_node.backward()
train_state.step += 1
train_state.samples += src.shape[0]
train_state.tokens += ntokens
if i % accum_iter == 0: