一文搞懂Transformer-decoder

目录

一、解码器(Decoder)

二、掩码多头注意力

三、预测(Token Prediction)

四、decoder原理代码复现


一、解码器(Decoder)

        我们看一下到目前为止我们已经完成了哪些内容,以及我们还需要完成哪些内容: 

        我们不会计算整个解码器,因为它的大部分已经在编码器中完成了类似的计算,详细计算解码器只会使文章变得冗长,因为重复的步骤很多;下面,我们会更多关注解码器输入和输出的计算。

        解码器有两个输入,一个输入来自编码器,顶层编码器的输出转换为一组注意力向量K和V;这些向量将在每个解码器的“交叉注意力”层中使用,用于帮助解码器集中注意力于输入序列中的适当位置;第二个输入是预测文本,假设我们输入给编码器的是 "用简单语言讲解Transformer",解码器的输入是预测的文本"太棒了"。

 

        以上动图展示了Transformer解码器训练过程在机器翻译任务中的运用。

        接下来的步骤会重复这个过程,直到达到一个特殊的结束符,表示Transformer解码器已经完成了输出;每个步骤的输出都会在下一个时间步骤中被传递给底层解码器,解码器们会像编码器一样层层传递解码结果,并且,就像我们对编码器输入所做的那样,我们对这些解码器输入进行嵌入和添加位置编码,以指示每个单词的位置。

        以上动图展示了Transformer解码器训练过程在机器翻译任务中的运用。

        但是训练的目标输入文本需要遵循一种标准的token封装方式,这样Transformer就知道从哪里开始和结束。

 

        在这里,引入了两个新的token,分别是<|im_start|>和<|im_end|>,此外,解码器一次只能接受一个token作为输入,也就是说,<|im_start|>会被作为一个输入,而"太”就是下一个预测token。 

        正如我们所知,这些词嵌入是有随机权重值的,这些权重将在训练过程中进行更新。

        如下图所示,以之前在编码器部分计算的相同方式计算剩余的块。

        在深入了解之前,我们需要通过一个简单的例子来理解什么是掩码多头注意力(masked multi-head attention)的概念。

二、掩码多头注意力

        在Transformer中,掩码多头注意力就像是模型用来聚焦句子不同部分的灯光,它的特殊之处在于,它不允许模型查看句子中后面的单词,避免了作弊;这有助于模型逐步理解和生成句子,在对话或将单词翻译成其他语言等任务中尤为重要。

        我们以上面的目标输入(”太棒了“)矩阵为例,其中每一行代表序列中的一个位置,每一列代表一个特征,如下所示为7 * 6的矩阵(d_model为6)。 

现在,让我们来了解具有两个头的掩码多头注意力组件的构成:

  1. 线性投影(Q,K,V):假设每个头的线性投影为 Head 1: Wq1,Wk1,Wv1 和 Head 2: Wq2,Wk2,Wv2。
  2. 计算注意力分数:对于每个头,通过Q和K的点积计算注意力分数,并应用掩码以防止关注未来位置。
  3. 应用Softmax函数:将注意力分数应用于softmax函数,获得注意力权重。
  4. 加权求和(V):将注意力权重与值相乘,获得每个头的加权和。
  5. 加法和线性变换:将两个头的输出连接起来,并应用线性变换。

进行如下简化的计算:

 

        加法和线性变换步骤将两个注意力头的输出合并为一组信息,这一步骤有助于从多个角度捕捉输入数据的不同方面,并提供一个更丰富的表示,供模型进一步处理使用。

三、预测(Token Prediction)

        解码器最后一个加法和归一化层的输出矩阵必须具有与输入矩阵相同的行数,而列数可以是任意的。在这里,我们使用6列。

        为了将解码器最后一个加法和归一化层的结果矩阵与一个线性层匹配,必须将其展平,以求得语料库中每个唯一Token的预测概率。 

展平的层将通过一个线性层来计算语料库中每个唯一Token的逻辑值(分数)。

 

一旦我们获得逻辑值,我们可以使用softmax函数对它们进行归一化,并找到概率最高的Token。

 

 

并且,用户可以用temperature参数来控制softmax函数的输出分布,以增加或减少随机性。

根据我们的计算,解码器预测的单词是"太",它的Token是8192和103。

 

 

        这个预测的单词"太"将被视为解码器的输入单词,并且这个过程会一直持续,直到预测到<|im_end|>标记。

四、decoder原理代码复现

https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoderLayer.html 

dim=6,head=1

import torch

encoder_batch_size = 1
encoder_seq = 1
dim = 6
encoder_output = torch.rand(encoder_batch_size,encoder_seq,dim)

decoder_batch_size = 1
decoder_seq = 1
decoder_heads = 1
decoder_input = torch.rand(decoder_batch_size,decoder_seq,dim)

decoder_layer = torch.nn.TransformerDecoderLayer(dim,decoder_heads,dropout=0.0,batch_first=True)
out = decoder_layer(decoder_input,encoder_output)
print(out)

def my_scaled_dot_product(query,key,value):
    qk_T = torch.mm(query,key.T)
    qk_T_scale = qk_T / torch.sqrt(torch.tensor(value.shape[1]))
    qk_exp = torch.exp(qk_T_scale)
    qk_exp_sum = torch.sum(qk_exp,dim=1,keepdim=True)
    qk_softmax = qk_exp / qk_exp_sum
    v_attn = torch.mm(qk_softmax,value)
    return v_attn,qk_softmax
# 第一次multi-head
first_in_proj_weight = decoder_layer.state_dict()['self_attn.in_proj_weight']
first_in_proj_bias = decoder_layer.state_dict()['self_attn.in_proj_bias']
first_out_proj_weight = decoder_layer.state_dict()['self_attn.out_proj.weight']
first_out_proj_bias = decoder_layer.state_dict()['self_attn.out_proj.bias']
first_batch_V_output = torch.empty(decoder_batch_size,decoder_seq,dim)
for i in range(decoder_batch_size):
    first_in_proj = torch.mm(decoder_input[i],first_in_proj_weight.T) + first_in_proj_bias
    Qs,Ks,Vs = torch.split(first_in_proj,dim,dim=-1)
    head_Vs = []
    for Q,K,V in zip(torch.split(Qs,dim//decoder_heads,dim=-1),torch.split(Ks,dim//decoder_heads,dim=-1),torch.split(Vs,dim//decoder_heads,dim=-1)):
        head_v,_ = my_scaled_dot_product(Q,K,V)
        head_Vs.append(head_v)
    V_cat = torch.cat(head_Vs,dim=-1)
    V_ouput = torch.mm(V_cat,first_out_proj_weight.T) + first_out_proj_bias
    first_batch_V_output[i] = V_ouput
# 第一次加
first_Add = decoder_input + first_batch_V_output
# 第一次layer_norm
norm1_mean = torch.mean(first_Add,dim=-1,keepdim=True)
norm1_std = torch.sqrt(torch.var(first_Add,unbiased=False,dim=-1,keepdim=True) + 1e-5)
norm1_weight = decoder_layer.state_dict()['norm1.weight']
norm1_bias = decoder_layer.state_dict()['norm1.bias']
norm1 = ((first_Add - norm1_mean)/norm1_std) * norm1_weight + norm1_bias
# 第二次multi-head
second_in_proj_weight = decoder_layer.state_dict()['multihead_attn.in_proj_weight']
second_in_proj_bias = decoder_layer.state_dict()['multihead_attn.in_proj_bias']
second_out_proj_weight = decoder_layer.state_dict()['multihead_attn.out_proj.weight']
second_out_proj_bias = decoder_layer.state_dict()['multihead_attn.out_proj.bias']
second_batch_V_output = torch.empty(decoder_batch_size,decoder_seq,dim)
for i in range(decoder_batch_size):
    Qs_weight,Ks_weight,Vs_weight = torch.split(second_in_proj_weight.T,dim,dim=-1)
    Qs_bias,Ks_bias,Vs_bias = torch.split(second_in_proj_bias,dim,dim=-1)
    Qs = torch.mm(norm1[i],Qs_weight) + Qs_bias
    Ks = torch.mm(encoder_output[i],Ks_weight) + Ks_bias
    Vs = torch.mm(encoder_output[i],Vs_weight) + Vs_bias
    head_Vs = []
    for Q,K,V in zip(torch.split(Qs,dim//decoder_heads,dim=-1),torch.split(Ks,dim//decoder_heads,dim=-1),torch.split(Vs,dim//decoder_heads,dim=-1)):
        head_v,_ = my_scaled_dot_product(Q,K,V)
        head_Vs.append(head_v)
    V_cat = torch.cat(head_Vs,dim=-1)
    V_ouput = torch.mm(V_cat,second_out_proj_weight.T) + second_out_proj_bias
    second_batch_V_output[i] = V_ouput
# 第二次加
second_Add = norm1 + second_batch_V_output
# 第二次layer_norm
norm2_mean = torch.mean(second_Add,dim=-1,keepdim=True)
norm2_std = torch.sqrt(torch.var(second_Add,unbiased=False,dim=-1,keepdim=True) + 1e-5)
norm2_weight = decoder_layer.state_dict()['norm2.weight']
norm2_bias = decoder_layer.state_dict()['norm2.bias']
norm2 = ((second_Add - norm2_mean)/norm2_std) * norm2_weight + norm2_bias
# feed forward
linear1_weight = decoder_layer.state_dict()['linear1.weight']
linear1_bias = decoder_layer.state_dict()['linear1.bias']
linear2_weight = decoder_layer.state_dict()['linear2.weight']
linear2_bias = decoder_layer.state_dict()['linear2.bias']
linear1 = torch.matmul(norm2,linear1_weight.T) + linear1_bias
linear1_relu = torch.nn.functional.relu(linear1)
linear2 = torch.matmul(linear1_relu,linear2_weight.T) + linear2_bias
# 第三次加
third_Add = norm2 + linear2
# 第三次layer_norm
norm3_mean = torch.mean(third_Add,dim=-1,keepdim=True)
norm3_std = torch.sqrt(torch.var(third_Add,unbiased=False,dim=-1,keepdim=True) + 1e-5)
norm3_weight = decoder_layer.state_dict()['norm3.weight']
norm3_bias = decoder_layer.state_dict()['norm3.bias']
norm3 = ((third_Add - norm3_mean)/norm3_std) * norm3_weight + norm3_bias
print(norm3)

 

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值