"Attention Is All You Need": A PyTorch Implementation

This post walks through a PyTorch implementation of the self-attention module from "Attention Is All You Need".

Self Attention

Diagram

Figure: (left) Scaled Dot-Product Attention; (right) Multi-Head Attention.
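
The left panel corresponds to the scaled dot-product attention formula from the paper, which the code below implements:

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

where d_k is the per-head dimension (head_dim in the code).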

Code Implementation

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        """
        :param embed_size: int, dimension of the token embeddings
        :param heads: int, number of attention heads
        """
        super(SelfAttention, self).__init__()

        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads."

        # Per-head linear projections (applied to the last dim of size head_dim)
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)

        # Output layer that merges the heads back to embed_size
        self.fc_out = nn.Linear(self.head_dim * heads, embed_size)

    def forward(self, values, keys, queries, mask):
        """
        :param values:  (N, value_len, embed_size)
        :param keys:    (N, key_len, embed_size)
        :param queries: (N, query_len, embed_size)
        :param mask:    broadcastable to (N, heads, query_len, key_len), or None
        :return out:    (N, query_len, embed_size)
        """
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        # Project each head separately
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Dot product of every query with every key, per batch element and head
        # energy shape: (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k) = sqrt(head_dim) as in the paper, softmax over the key dimension
        # attention shape: (N, heads, query_len, key_len)
        attention = torch.softmax(energy / (self.head_dim ** (1 / 2)), dim=3)

        # value_len always equals key_len
        # einsum output: (N, query_len, heads, head_dim); reshape concatenates the heads
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )

        # out shape: (N, query_len, embed_size)
        out = self.fc_out(out)

        return out
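
A minimal usage sketch (batch size, sequence length, and dimensions are arbitrary values chosen for illustration), using the class defined above:

import torch

embed_size, heads = 256, 8
attention = SelfAttention(embed_size, heads)

N, seq_len = 4, 10
x = torch.randn(N, seq_len, embed_size)

# Self-attention: values, keys and queries are the same sequence; no mask
out = attention(x, x, x, mask=None)
print(out.shape)  # torch.Size([4, 10, 256])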

Notes

The core of the forward pass is the einsum that computes the raw attention scores:

energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

For each batch element n and head h, it takes the dot product of every query vector with every key vector over the head_dim axis, producing an energy tensor of shape (N, heads, query_len, key_len).
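
If the subscript notation is unfamiliar, the following sketch (with arbitrary example sizes) shows that this einsum is equivalent to moving the head dimension forward and doing a batched matrix multiplication:

import torch

N, query_len, key_len, heads, head_dim = 2, 5, 7, 4, 8
queries = torch.randn(N, query_len, heads, head_dim)
keys = torch.randn(N, key_len, heads, head_dim)

# einsum form used in the module
energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

# equivalent form: put heads before the sequence dims, then batched matmul
q = queries.permute(0, 2, 1, 3)  # (N, heads, query_len, head_dim)
k = keys.permute(0, 2, 3, 1)     # (N, heads, head_dim, key_len)
assert torch.allclose(energy, q @ k, atol=1e-5)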
