【课程总结】day20:Transformer源码深入理解之训练过程

前言

在上一章【课程总结】day19(下):Transformer源码深入理解总结中,我们对Transformer架构以及初始化部分做了梳理,本章我们将对Transformer训练过程进行代码分析理解。

训练流程

  • 训练过程主要由四个主要部分组成:
  • 第一部分:加载数据集。通过get_dataloader()加载数据集,这一过程与Seq2Seq类似,这里不再赘述。
  • 第二部分:数据对齐。调用collate_fn()函数,对数据集整理并对齐。
  • 第三部分:开始训练。调用run_epoch()函数,循环遍历dataloader中的数据集并进行训练。
  • 第四部分:前向传播。在训练过程中,调用EncoderDecoder的forward()函数进行前向传播。

代码分析理解

数据对齐 collate_fn()

源码如下:

def collate_fn(batch, tokenizer):

    input_sentences, input_sentence_lens, output_sentences, output_sentence_lens = zip(
        *batch
    )

    # 转索引【按本批量最大长度来填充】
    input_sentence_len = max(input_sentence_lens)
    input_idxes = []
    for input_sentence in input_sentences:
        input_idxes.append(tokenizer.encode_input(input_sentence, input_sentence_len))

    # 转索引【按本批量最大长度来填充】
    output_sentence_len = max(output_sentence_lens)
    output_idxes = []
    for output_sentence in output_sentences:
        output_idxes.append(
            tokenizer.encode_output(output_sentence, output_sentence_len)
        )
    # 转张量 [batch_size, seq_len]  src
    input_idxes = torch.LongTensor(input_idxes)
    # src_mask [batch_size, 1, seq_len]
    input_mask = (input_idxes != tokenizer.input_word2idx.get("<PAD>")).unsqueeze(-2)
    # tgt [batch_size, seq_len]
    output_idxes = torch.LongTensor(output_idxes)
    # tgt [batch_size, seq_len - 1] 去掉最后一个
    output_idxes_in = output_idxes[:, :-1]
    # tgt_y [batch_size, seq_len - 1] 去掉开头 的 SOS
    output_idxes_out = output_idxes[:, 1:]
    # tgt_mask [batch_size, seq_len-1, seq_len-1]
    output_mask = tokenizer.make_std_mask(output_idxes_in, tokenizer.output_word2idx.get("<PAD>"))
    # 记录生成的有效字符
    ntokens = (output_idxes_out != tokenizer.output_word2idx.get("<PAD>")).data.sum()
    # src, src_mask, tgt, tgt_mask, tgt_y, ntokens
    return input_idxes, input_mask, output_idxes_in, output_mask, output_idxes_out, ntokens

代码理解:
第一步:提取输入数据、输入数据长度、输出数据、输出数据长度

    input_sentences, input_sentence_lens, output_sentences, output_sentence_lens = zip(
        *batch
    )
  • batch:一个包含多个样本的列表;
  • zip(*batch):将 batch 中的每个样本解包,分别提取出输入句子、输入句子长度、输出句子和输出句子长度。

示例理解:

input_sentences:                   output_sentences:
['I', 'm', 'sick', '.']           ['我', '病' , '了', '。']
['I', 'm', 'tall', '.']           ['我', '个子' , '高', '。']
['Leave', 'me', '.']              ['让', '我', '一个人', '呆','会', '儿','。']

input_sentence_lens:               output_sentence_lens:
[4, 4, 3]                         [447]

第二步:对数据进行填充

# 输入的最大长度为4,所以input_idxes填充为
['I', 'm', 'sick', '.'] 
['I', 'm', 'tall', '.'] 
['Leave', 'me', '.', '<PAD>']   


# 输出的最大长度为7,所以output_idxes填充为
['<SOS>', '我', '病' , '了', '。', '<EOS>', '<PAD>', '<PAD>', '<PAD>']
['<SOS>', '我', '个子' , '高', '。', '<EOS>', '<PAD>', '<PAD>', '<PAD>']
['<SOS>', '让', '我', '一个人', '呆', '会', '儿', '。', '<EOS>']

第三步:生成input的mask

    input_mask = (input_idxes != tokenizer.input_word2idx.get("<PAD>")).unsqueeze(-2)
  • 输入位置不为<PAD>的位置,值为1,否则为0,从而形成mask。

第四步:生成错位的output

    # tgt [batch_size, seq_len - 1] 去掉最后一个
    output_idxes_in = output_idxes[:, :-1]
    # 例如:['<SOS>', '让', '我', '一个人', '呆', '会', '儿', '。']

    # tgt_y [batch_size, seq_len - 1] 去掉开头 的 SOS
    output_idxes_out = output_idxes[:, 1:]
    # 例如:['让', '我', '一个人', '呆', '会', '儿', '。', '<EOS>']

第五步:生成output的mask

    output_mask = tokenizer.make_std_mask(output_idxes_in, tokenizer.output_word2idx.get("<PAD>"))
    def make_std_mask(cls, tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & Tokenizer.subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)
        return tgt_mask

代码理解:

  • 因为decoder的掩码多头注意力(mask MultiHeadAttention),既要屏蔽无效的PAD,同时还要屏蔽未来词。
  • 所以tgt_mask是由tgt_mask & Tokenizer.subsequent_mask两部分按位与运算,即两者都为1才是有效的,如果有一个为0,则对应数据被遮挡。
  • Tokenizer.subsequent_mask 是生成一个三角矩阵,如下图:

开始训练 run_epoch()

def run_epoch(
        data_iter,
        model,
        loss_compute,
        optimizer,
        scheduler,
        mode="train",
        accum_iter=1,
        train_state=TrainState(),
        device="cpu"
):
    """
    Train a single epoch
    """
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    n_accum = 0
    for i, (src, src_mask, tgt, tgt_mask, tgt_y, ntokens) in enumerate(data_iter):
        #
        src = src.to(device=device)
        tgt = tgt.to(device=device)
        tgt_y = tgt_y.to(device=device)
        # src = src.to(device=device)
        out = model.forward(src, tgt, src_mask, tgt_mask)
        loss, loss_node = loss_compute(out, tgt_y, ntokens)
        # loss_node = loss_node / accum_iter
        if mode == "train" or mode == "train+log":
            loss_node.backward()
            train_state.step += 1
            train_state.samples += src.shape[0]
            train_state.tokens += ntokens
            if i % accum_iter == 0:
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值