Xiao Hei's Line-by-Line Dimension Analysis and Debugging: MatchingAttention


import torch
import torch.nn as nn
import torch.nn.functional as F
class MatchingAttention(nn.Module):
    def __init__(self, mem_dim, cand_dim, alpha_dim=None, att_type='general'):
        super(MatchingAttention, self).__init__()
        # 'concat' needs an alpha_dim; 'dot' needs matching memory/candidate dims
        assert att_type != 'concat' or alpha_dim is not None
        assert att_type != 'dot' or mem_dim == cand_dim
        self.mem_dim = mem_dim      # D_g
        self.cand_dim = cand_dim    # D_m
        self.att_type = att_type
        if att_type == 'general':
            self.transform = nn.Linear(cand_dim, mem_dim, bias=False)
        elif att_type == 'general2':
            self.transform = nn.Linear(cand_dim, mem_dim, bias=True)
        elif att_type == 'concat':
            self.transform = nn.Linear(cand_dim + mem_dim, alpha_dim, bias=False)
            self.vector_prod = nn.Linear(alpha_dim, 1, bias=False)
    def forward(self, M, x, mask=None):
        """
        M -> (seq_len, batch, mem_dim)
        x -> (batch, cand_dim)
        mask -> (batch, seq_len)
        """
        # M: [t-1, batch_size, D_g]
        # x: [batch_size, D_m]
        if mask is None:
            # mask: [batch_size, t-1]
            mask = torch.ones(M.size(1), M.size(0)).type(M.type())
        if self.att_type == 'dot':
            M_ = M.permute(1, 2, 0)    # [batch_size, D_g, t-1]
            x_ = x.unsqueeze(1)        # [batch_size, 1, D_m]
            alpha = F.softmax(torch.bmm(x_, M_), dim=2)    # [batch_size, 1, t-1]
        elif self.att_type == 'general':
            M_ = M.permute(1, 2, 0)    # [batch_size, D_g, t-1]
            x_ = self.transform(x).unsqueeze(1)    # [batch_size, 1, D_g]
            alpha = F.softmax(torch.bmm(x_, M_), dim=2)    # [batch_size, 1, t-1]
        elif self.att_type == 'general2':
            M_ = M.permute(1, 2, 0)    # [batch_size, D_g, t-1]
            x_ = self.transform(x).unsqueeze(1)    # [batch_size, 1, D_g]
            # alpha_: [batch_size, 1, t-1]
            alpha_ = F.softmax(torch.bmm(x_, M_) * mask.unsqueeze(1), dim=2)
            # alpha_masked: [batch_size, 1, t-1]
            alpha_masked = alpha_ * mask.unsqueeze(1)
            # alpha_sum: [batch_size, 1, 1]
            alpha_sum = torch.sum(alpha_masked, dim=2, keepdim=True)
            # renormalize over the unmasked steps: [batch_size, 1, t-1]
            alpha = alpha_masked / alpha_sum
        else:
            M_ = M.transpose(0, 1)    # [batch_size, t-1, D_g]
            x_ = x.unsqueeze(1).expand(-1, M.size()[0], -1)    # [batch_size, t-1, D_m]
            M_x_ = torch.cat([M_, x_], 2)    # [batch_size, t-1, D_g + D_m]
            mx_a = torch.tanh(self.transform(M_x_))    # [batch_size, t-1, alpha_dim]
            alpha = F.softmax(self.vector_prod(mx_a), 1).transpose(1, 2)    # [batch_size, 1, t-1]
        attn_pool = torch.bmm(alpha, M.transpose(0, 1))[:, 0, :]    # [batch_size, D_g]
        return attn_pool, alpha
mem_dim = 100
cand_dim = 150
attn = MatchingAttention(mem_dim, cand_dim)    # default att_type='general'
# seq_len, i.e. t-1, is 5
t_1 = 5
batch_size = 8

M = torch.randn([t_1, batch_size, mem_dim])
x = torch.randn([batch_size, cand_dim])
attn_pool, alpha = attn(M, x)
print('attn_pool:', attn_pool.shape)
print('alpha:', alpha.shape)

Output:

attn_pool: torch.Size([8, 100])
alpha: torch.Size([8, 1, 5])
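
The run above only exercises the default att_type='general'. As a quick check of the remaining branches, here is a minimal sketch that reuses the MatchingAttention class and the M, x, mem_dim, cand_dim, t_1, and batch_size defined above; the mask pattern and alpha_dim=64 are arbitrary values chosen only for illustration:

# Sketch: exercise the other attention types (general2 with a mask, concat, dot).
# Reuses MatchingAttention, M, x, mem_dim, cand_dim, t_1, batch_size from above.
mask = torch.ones(batch_size, t_1)
mask[:, -2:] = 0    # illustrative mask: treat the last two time steps as padding

attn_g2 = MatchingAttention(mem_dim, cand_dim, att_type='general2')
pool_g2, alpha_g2 = attn_g2(M, x, mask)
print('general2 attn_pool:', pool_g2.shape)    # torch.Size([8, 100])
print('general2 alpha:', alpha_g2.shape)       # torch.Size([8, 1, 5])

attn_cat = MatchingAttention(mem_dim, cand_dim, alpha_dim=64, att_type='concat')
pool_cat, alpha_cat = attn_cat(M, x)
print('concat attn_pool:', pool_cat.shape)     # torch.Size([8, 100])
print('concat alpha:', alpha_cat.shape)        # torch.Size([8, 1, 5])

# 'dot' requires mem_dim == cand_dim, so use a candidate vector of size mem_dim here
attn_dot = MatchingAttention(mem_dim, mem_dim, att_type='dot')
pool_dot, alpha_dot = attn_dot(M, torch.randn(batch_size, mem_dim))
print('dot attn_pool:', pool_dot.shape)        # torch.Size([8, 100])

All variants pool the memory back to [batch_size, D_g]; they differ only in how the attention weights alpha over the t-1 memory steps are computed.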
