nn.MultiheadAttention详解 -- forward()中维度、计算方式

nn.MultiheadAttention

        该模块兼顾了 self-attention 和 cross-attention;

        是构成  nn.transformer 的核心算子;

        首先看其接口文档:

CLASStorch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, 
add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, 
device=None, dtype=None)

        参数解释

        embed_dim: 样本序列token的嵌入维度;

        num_heads: 并行注意力头的数量;

        kdim: 注意力中key的维度;

        vdim: 注意力中value的维度;

        forward() 参数:

forward(query, key, value, key_padding_mask=None, need_weights=True, 
attn_mask=None, average_attn_weights=True)

        参数解释:

        query: 含有batch时,其shape为(L,N,Eq), 这里的L表示target embedding_dim, N 表示 batch size, Eq = embed_dim;

        key: 含有batch时,其shape为(S,N,Ek), 这里的S表示source embedding_dim, N 表示 batch size, Ek = kdim;

        value: 含有batch时,其shape为(S,N,Ev), 这里的S表示source embedding_dim, N 表示 batch size, Ev = vdim;

        average_attn_weights : 返回的attention weights,默认返回各头的平均值;

        通过上述参数的解释,可以看到K, V是来自同一个来源,对应于transformer中的encoder -- shape对应于 source embedding_dim;

                                                Q 是另外一个来源,对应于transformer中的decoder -- shape对应于 target embedding_dim;

        也是cross-attention使用的;

        self-attention使用:

        在forward()中的, query、key 和 value 是相同维度的张量,即: embed_dim = kdim = vdim;

        forward()中的计算流程:

        首先初始化,

                

                  

                 

        之后使用

                 

                

                

        最后得到返回值,也就是average_attn_weights:

                

         对应的维度解释:

                N 即 batch size提到最前面,然后做后两维的转置;

                

                 

                对于V,N提前:

                 

                 则最后的维度为:

                

        通过上述计算就可以得到最终的attention weights output

 参考视频:[pytorch模型拓扑结构] nn.MultiheadAttention, init/forward, 及 query,key,value 的计算细节_哔哩哔哩_bilibili

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值