import torch
import torch.nn as nn
import torch.nn.functional as F
class MatchingAttention(nn.Module):
    def __init__(self, mem_dim, cand_dim, alpha_dim=None, att_type='general'):
        super(MatchingAttention, self).__init__()
        assert att_type != 'concat' or alpha_dim is not None
        assert att_type != 'dot' or mem_dim == cand_dim
        self.mem_dim = mem_dim    # D_g
        self.cand_dim = cand_dim  # D_m
        self.att_type = att_type
        if att_type == 'general':
            self.transform = nn.Linear(cand_dim, mem_dim, bias=False)
        elif att_type == 'general2':
            self.transform = nn.Linear(cand_dim, mem_dim, bias=True)
        elif att_type == 'concat':
            self.transform = nn.Linear(cand_dim + mem_dim, alpha_dim, bias=False)
            self.vector_prod = nn.Linear(alpha_dim, 1, bias=False)
    def forward(self, M, x, mask=None):
        """
        M    -> (seq_len, batch, mem_dim)
        x    -> (batch, cand_dim)
        mask -> (batch, seq_len)
        """
        # M: [t-1, batch_size, D_g]
        # x: [batch_size, D_m]
        if mask is None:
            # mask: [batch_size, t-1]
            mask = torch.ones(M.size(1), M.size(0)).type(M.type())
        if self.att_type == 'dot':
            M_ = M.permute(1, 2, 0)               # [batch_size, D_g, t-1]
            x_ = x.unsqueeze(1)                   # [batch_size, 1, D_m]
            alpha = F.softmax(torch.bmm(x_, M_), dim=2)   # [batch_size, 1, t-1]
        elif self.att_type == 'general':
            M_ = M.permute(1, 2, 0)               # [batch_size, D_g, t-1]
            x_ = self.transform(x).unsqueeze(1)   # [batch_size, 1, D_g]
            alpha = F.softmax(torch.bmm(x_, M_), dim=2)   # [batch_size, 1, t-1]
        elif self.att_type == 'general2':
            M_ = M.permute(1, 2, 0)               # [batch_size, D_g, t-1]
            x_ = self.transform(x).unsqueeze(1)   # [batch_size, 1, D_g]
            # alpha_: [batch_size, 1, t-1]
            alpha_ = F.softmax(torch.bmm(x_, M_) * mask.unsqueeze(1), dim=2)
            # alpha_masked: [batch_size, 1, t-1], zero weight on padded positions
            alpha_masked = alpha_ * mask.unsqueeze(1)
            # [batch_size, 1, 1]
            alpha_sum = torch.sum(alpha_masked, dim=2, keepdim=True)
            # renormalize so the unmasked weights sum to 1: [batch_size, 1, t-1]
            alpha = alpha_masked / alpha_sum
        else:  # 'concat'
            M_ = M.transpose(0, 1)                          # [batch_size, t-1, D_g]
            x_ = x.unsqueeze(1).expand(-1, M.size(0), -1)   # [batch_size, t-1, D_m]
            M_x_ = torch.cat([M_, x_], 2)                   # [batch_size, t-1, D_g + D_m]
            mx_a = torch.tanh(self.transform(M_x_))         # [batch_size, t-1, alpha_dim]
            alpha = F.softmax(self.vector_prod(mx_a), dim=1).transpose(1, 2)  # [batch_size, 1, t-1]
        attn_pool = torch.bmm(alpha, M.transpose(0, 1))[:, 0, :]  # [batch_size, D_g]
        return attn_pool, alpha
mem_dim = 100
cand_dim = 150
attn = MatchingAttention(mem_dim, cand_dim)   # default att_type='general'
# t-1 = 5
t_1 = 5
batch_size = 8
M = torch.randn(t_1, batch_size, mem_dim)
x = torch.randn(batch_size, cand_dim)
attn_pool, alpha = attn(M, x)
print('attn_pool:', attn_pool.shape)
print('alpha:', alpha.shape)
Output:
attn_pool: torch.Size([8, 100])
alpha: torch.Size([8, 1, 5])
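The run above only exercises the default 'general' attention. As a minimal sketch (reusing M, x, t_1, batch_size and the dimensions from the snippet above; the mask values are made up), the 'general2' type also accepts an explicit padding mask: masked timesteps get exactly zero weight and the remaining weights are renormalized to sum to 1.

# sketch: masked attention with att_type='general2' (hypothetical padding pattern)
attn_g2 = MatchingAttention(mem_dim, cand_dim, att_type='general2')
mask = torch.ones(batch_size, t_1)
mask[:, -2:] = 0                      # pretend the last two timesteps are padding
pool_g2, alpha_g2 = attn_g2(M, x, mask)
print(alpha_g2[0])                    # weights at the masked positions are 0
print(alpha_g2.sum(dim=2))            # each sample's weights still sum to 1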
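The 'dot' and 'concat' types use the same forward interface; the dimensions below are illustrative. Note that 'dot' requires mem_dim == cand_dim (the candidate is matched against memory directly), while 'concat' needs an alpha_dim for the hidden projection.

# sketch: the other two attention types (hypothetical dimensions)
attn_dot = MatchingAttention(100, 100, att_type='dot')                    # mem_dim == cand_dim
attn_cat = MatchingAttention(100, 150, alpha_dim=64, att_type='concat')
pool_d, alpha_d = attn_dot(torch.randn(5, 8, 100), torch.randn(8, 100))
pool_c, alpha_c = attn_cat(torch.randn(5, 8, 100), torch.randn(8, 150))
print(pool_d.shape, alpha_d.shape)   # torch.Size([8, 100]) torch.Size([8, 1, 5])
print(pool_c.shape, alpha_c.shape)   # torch.Size([8, 100]) torch.Size([8, 1, 5])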