Input:
features: [num_seqs, batch_size, 2*D_e]
lengths: [batch_size]
edge_ind: batch_size lists of (a, b) node-index pairs, one list per dialogue
Output:
scores: [batch_size, max_nodes, num_seqs] (num_seqs = node_number)
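For reference, edge_ind holds one list of (source, target) node-index pairs per dialogue in the batch. A minimal sketch with hypothetical toy sizes (batch_size = 2; dialogues of 3 and 2 nodes, not the sizes used in the demo below):

edge_ind_example = [
    [(0, 0), (0, 1), (1, 2)],  # edges of dialogue 0 (nodes 0..2)
    [(0, 1), (1, 0)],          # edges of dialogue 1 (nodes 0..1)
]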
Code demo:
import sys
sys.path.append('.')

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from model import SimpleAttention, MatchingAttention, Attention


class MaskedEdgeAttention(nn.Module):
    def __init__(self, input_dim, max_seq_len, no_cuda):
        super(MaskedEdgeAttention, self).__init__()
        self.input_dim = input_dim
        self.max_seq_len = max_seq_len
        self.scalar = nn.Linear(self.input_dim, self.max_seq_len, bias=False)
        self.matchatt = MatchingAttention(self.input_dim, self.input_dim, att_type='general2')
        self.simpleatt = SimpleAttention(self.input_dim)
        self.att = Attention(self.input_dim, score_function='mlp')
        self.no_cuda = no_cuda

    def forward(self, M, lengths, edge_ind):
        # M: [num_seqs, batch_size, 2*D_e]
        # lengths: [batch_size]
        # edge_ind: batch_size lists of (a, b) node-index pairs
        attn_type = 'attn1'
        if attn_type == 'attn1':
            # scale: [num_seqs, batch_size, max_nodes]
            scale = self.scalar(M)
            # alpha: [batch_size, max_nodes, num_seqs]; softmax over the num_seqs axis
            alpha = F.softmax(scale, dim=0).permute(1, 2, 0)
            if not self.no_cuda:
                # mask: [batch_size, max_nodes, num_seqs]
                mask = Variable(torch.ones(alpha.size()) * 1e-10).detach().cuda()
                # mask_copy: [batch_size, max_nodes, num_seqs]
                mask_copy = Variable(torch.zeros(alpha.size())).detach().cuda()
            else:
                mask = Variable(torch.ones(alpha.size()) * 1e-10).detach()
                mask_copy = Variable(torch.zeros(alpha.size())).detach()

            # Collect one (batch_index, source_node, target_node) triple per edge.
            edge_ind_ = []
            for i, j in enumerate(edge_ind):
                for x in j:
                    edge_ind_.append([i, x[0], x[1]])
            # edge_ind_: [3, num_edges] = [[i, i, ...], [x[0], x[0], ...], [x[1], x[1], ...]]
            edge_ind_ = np.array(edge_ind_).transpose()
            # Set edge positions to 1; mask: [batch_size, max_nodes, num_seqs]
            mask[edge_ind_[0], edge_ind_[1], edge_ind_[2]] = 1
            # mask_copy: [batch_size, max_nodes, num_seqs]
            mask_copy[edge_ind_[0], edge_ind_[1], edge_ind_[2]] = 1
            # masked_alpha: [batch_size, max_nodes, num_seqs]
            masked_alpha = alpha * mask
            # _sums: [batch_size, max_nodes, 1]
            _sums = masked_alpha.sum(-1, keepdim=True)
            # Renormalize over connected nodes only; scores: [batch_size, max_nodes, num_seqs]
            scores = masked_alpha.div(_sums) * mask_copy
            return scores
# Demo: random features and random edges, just to check the output shape.
num_seqs = 20
D_e = 100
dim = 2 * D_e
num_edges = 55
batch_size = 32

attn = MaskedEdgeAttention(dim, num_seqs, no_cuda=False).cuda()
M = torch.randn([num_seqs, batch_size, 2 * D_e]).cuda()
lengths = batch_size * [num_seqs]  # per-dialogue utterance counts (unused by the 'attn1' branch)
edge_ind = [torch.randint(0, num_seqs, (num_edges, 2)).tolist() for _ in range(batch_size)]
print(attn(M, lengths, edge_ind).shape)
Output:
torch.Size([32, 20, 20])
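To make the masking behaviour concrete, here is a small sanity check (a hypothetical addition, not part of the original demo; it reuses the variables defined above): scores should be non-zero at edge positions, and each row belonging to a node with at least one outgoing edge should re-normalize to roughly 1 over its connected nodes.

# Hypothetical sanity check, reusing attn, M, lengths, edge_ind from the demo.
scores = attn(M, lengths, edge_ind)   # [batch_size, max_nodes, num_seqs]
b, (src, dst) = 0, edge_ind[0][0]     # first edge of the first dialogue
print(scores[b, src, dst].item())     # non-zero, since (src, dst) is an edge
print(scores[b, src].sum().item())    # ~1.0: the row re-normalizes over src's edges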