小黑维度分析与逐行调试:MaskedEdgeAttention

输入:

features:[num_seqs,batch_size,2*D_e]

lengths:[batch_size]

edge_ind:batch_size个[(a,b)…]

输出:

scores:[batch_size, max_nodes(=max_seq_len), num_seqs] —— 每个 (batch, node) 位置对序列各时刻的归一化注意力分布(仅在 edge_ind 给出的边上非零)

代码demo:

import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import sys
sys.path.append('.')
import torch
import numpy as np
from model import SimpleAttention,MatchingAttention,Attention

class MaskedEdgeAttention(nn.Module):
    
    def __init__(self,input_dim,max_seq_len,no_cuda):
        super(MaskedEdgeAttention,self).__init__()
        
        self.input_dim = input_dim
        self.max_seq_len = max_seq_len
        self.scalar = nn.Linear(self.input_dim,self.max_seq_len,bias = False)
        self.matchatt = MatchingAttention(self.input_dim,self.input_dim,att_type = 'general2')
        self.simpleatt = SimpleAttention(self.input_dim)
        self.att = Attention(self.input_dim,score_function = 'mlp')
        self.no_cuda = no_cuda
        
    def forward(self,M,lengths,edge_ind):
        # M: [num_seqs,batch_size,2*D_e]
        # lengths: [batch_size]
        # edge_ind: edge_ind:batch_size个[(a,b)....]
        
        attn_type = 'attn1'
        
        if attn_type == 'attn1':
            # scale:[num_seqs,batch_size,max_nodes]
            scale = self.scalar(M)
            # alpha:[batch_size,max_nodes,num_seqs]
            alpha = F.softmax(scale,dim = 0).permute(1,2,0)
            
            if not self.no_cuda:
                # mask:[batch_size,max_nodes,num_seqs]
                mask = Variable(torch.ones(alpha.size()) * 1e-10).detach().cuda()
                # mask_copy:[batch_size,max_nodes,num_seqs]
                mask_copy = Variable(torch.zeros(alpha.size())).detach().cuda()
            else:
                mask = Variable(torch.ones(alpha.size()) * 1e-10).detach()
                mask_copy = Variable(torch.zeros(alpha.size())).detach()
            edge_ind_ = []
            for i,j in enumerate(edge_ind):
                for x in j:
                    edge_ind_.append([i,x[0],x[1]])
            # edge_ind_:[3,num_edges]  [[i,i,i,i,i.....],[x[0],x[0].....],[x[1],x[1].....]]
            edge_ind_ = np.array(edge_ind_).transpose()
            # mask:[batch_size,max_nodes,num_seqs]
            mask[edge_ind_] = 1
            # mask_copy:[batch_size,max_nodes,num_seqs]
            mask_copy[edge_ind_] = 1
            # masked_alpha:[batch_size,max_nodes,num_seqs]
            masked_alpha = alpha * mask
            # _sums:[batch_size,max_nodes,1]
            _sums = masked_alpha.sum(-1,keepdim = True)
            # scores:[batch_size,max_nodes,num_seqs]
            scores = masked_alpha.div(_sums) * mask_copy
            return scores
# ---- demo / smoke test ----
num_seqs = 20      # sequence length (also the number of graph nodes)
D_e = 100          # per-direction feature size; features are 2*D_e wide
dim = 2 * D_e
num_edges = 55     # random edges generated per batch element
batch_size = 32

if __name__ == "__main__":
    # Pick the device at runtime instead of hard-coding .cuda(), so the
    # demo also runs on CPU-only machines (the original crashed there).
    use_cuda = torch.cuda.is_available()
    attn = MaskedEdgeAttention(dim, num_seqs, no_cuda=not use_cuda)
    M = torch.randn([num_seqs, batch_size, 2 * D_e])
    if use_cuda:
        attn = attn.cuda()
        M = M.cuda()
    # lengths holds per-sample sequence lengths (forward() ignores it for
    # 'attn1'); the original filled it with num_edges=55 > num_seqs, which
    # is semantically wrong even though unused.
    lengths = batch_size * [num_seqs]
    edge_ind = [torch.randint(0, num_seqs, (num_edges, 2)).tolist()
                for _ in range(batch_size)]
    print(attn(M, lengths, edge_ind).shape)

输出:

torch.Size([32, 20, 20])

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值