算法手撕面经系列(1)--手撕多头注意力机制

多头注意力机制

 一个简单的多头注意力模块可以分解为以下几个步骤:

  1. 先不分多头,对输入张量分别做变换,得到 Q , K , V Q,K,V Q,K,V
  2. 对得到的 Q , K , V Q,K,V Q,K,V按头的个数进行split;
  3. Q , K Q,K Q,K计算向量点积
  4. 考虑是否要添因果mask
  5. 利softmax计算注意力得分矩阵atten
  6. 对注意力得分矩阵施加Dropout
  7. 将atten矩阵和 V V V矩阵相乘
  8. 再过一道最终的输出变换

代码

 给出一个 d k = d v = d m o d e l d_k=d_v=d_{model} dk=dv=dmodel的多头注意力实现如下:


class MHA(nn.Module):
    def __init__(self,C_in,dmodel,num_head=8,p_drop=0.2):
        super(MHA, self).__init__()

        self.QW=nn.Linear(C_in,dmodel)
        self.KW=nn.Linear(C_in,dmodel)
        self.VW=nn.Linear(C_in,dmodel)

        self.dp=nn.Dropout(p_drop)

        self.W_concat=nn.Linear(dmodel,dmodel)

        self.n_head=num_head
        self.p_drop=p_drop
        self.depth=dmodel//num_head

    def forward(self,X,casual=True):
        B,L,C=X.shape
        Q=self.QW(X)
        K=self.KW(X)
        V=self.VW(X)

        Q=Q.reshape((B,L,self.n_head,-1)).permute(0,2,1,3)
        K=K.reshape((B,L,self.n_head,-1)).permute(0,2,1,3)
        V=V.reshape((B,L,self.n_head,-1)).permute(0,2,1,3)

        atten=Q.matmul(K.transpose(2,3))

        if casual:
            mask=torch.triu(torch.ones(L,L))
            atten=torch.where(mask==1,atten,torch.ones_like(atten)*(-2**32+1))
        atten=torch.softmax(atten,dim=-1)

        atten=self.dp(atten)

        out=torch.matmul(atten,V)/self.depth**(1/2)

        out=out.permute(0,2,1,3).reshape(B,L,-1)
        out=self.W_concat(out)

        return out


if __name__=="__main__":
    input=torch.rand(10,5,3)
    model=MHA(3,64,4)
    res=model(input)

回答: 在Transformer模型中,输入是通过Encoder和Decoder进行处理的。首先,在Encoder中,我们将输入的序列通过词嵌入(embedding)层进行编码,得到一个表示序列的向量。然后,输入的序列经过一系列的Encoder层,每个层都包含多头自注意力(self-attention)机制和前馈神经网络(feed-forward neural network)。在每个Encoder层中,我们使用Masked Multi-Head Attention机制来学习输入序列的依赖关系,并且添加一个残差连接和层归一化操作。最后,Encoder的输出是上面提到的最后一层的输出。 接下来,在Decoder中,我们使用相似的过程对输出序列进行处理。首先,输出序列通过词嵌入层进行编码,然后经过一系列的Decoder层。在每个Decoder层中,我们使用Multi-Head Attention机制来学习输入序列和输出序列之间的依赖关系,并且添加一个残差连接和层归一化操作。然后,通过一个前馈神经网络进行处理。Decoder的输出是通过和Encoder的输出进行注意力计算得到的,并且最终通过线性变换得到预测的输出序列。 总结起来,Transformer模型的输入是通过Encoder进行编码,然后通过一系列的Encoder层进行处理。输出是通过Decoder进行处理,最终得到预测的输出序列。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* *3* [【Transformer】Transformer输入输出细节以及代码实现(pytorch)](https://blog.csdn.net/wl1780852311/article/details/121033915)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 100%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值