1.MaskedNLLLoss
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedNLLLoss(nn.Module):
    def __init__(self, weight=None):
        super(MaskedNLLLoss, self).__init__()
        self.weight = weight
        self.loss = nn.NLLLoss(weight=weight, reduction='sum')

    def forward(self, pred, target, mask):
        # pred:   [batch_size * seq_len, n_classes], log-probabilities
        # target: [batch_size * seq_len]
        # mask:   [batch_size * seq_len], 1 for valid positions, 0 for padding
        mask_ = mask.view(-1, 1)  # [batch_size * seq_len, 1]
        if self.weight is None:
            # sum of masked NLL, averaged over the number of valid positions
            loss = self.loss(pred * mask_, target) / torch.sum(mask)
        else:
            # with class weights, normalize by the total weight of the valid positions
            loss = self.loss(pred * mask_, target) \
                   / torch.sum(self.weight[target] * mask_.squeeze())
        return loss
predict = torch.randn([5, 7, 10])  # random scores; NLLLoss actually expects log-probabilities
target = torch.ones([5, 7]).long()
mask = torch.ones([5, 7]).long()
loss = MaskedNLLLoss()
loss(predict.view(-1, 10), target.view(-1), mask.view(-1))
Output:
tensor(-0.0985)
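The negative value above comes from feeding raw random numbers instead of log-probabilities. As a sanity check (my own sketch, not part of the original code), we can feed log_softmax outputs and compare the result against a manually computed masked average:

# sanity check (not in the original): use real log-probabilities and a random mask
log_probs = F.log_softmax(torch.randn([5, 7, 10]), dim=-1).view(-1, 10)
target = torch.randint(0, 10, [35])
mask = torch.randint(0, 2, [35])

masked_loss = MaskedNLLLoss()(log_probs, target, mask)

# manual version: per-token NLL, zeroed where mask == 0, averaged over valid positions
per_token_nll = -log_probs[torch.arange(35), target]
manual = (per_token_nll * mask).sum() / mask.sum()
print(torch.allclose(masked_loss, manual))  # True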
2.SimpleAttention
class SimpleAttention(nn.Module):
    def __init__(self, input_dim):
        super(SimpleAttention, self).__init__()
        self.input_dim = input_dim
        self.scalar = nn.Linear(self.input_dim, 1, bias=False)

    def forward(self, M, x=None):
        # M: [t-1, batch_size, D_g]
        # x: [batch_size, D_m] (unused in this simple variant)
        scale = self.scalar(M)                            # [t-1, batch_size, 1]
        alpha = F.softmax(scale, dim=0).permute(1, 2, 0)  # [batch_size, 1, t-1]
        attn_pool = torch.bmm(alpha, M.transpose(0, 1))[:, 0, :]  # [batch_size, D_g]
        return attn_pool, alpha
D_g = 100
attn = SimpleAttention(D_g)
t_1 = 5
batch_size = 8
M = torch.randn([t_1, batch_size, D_g])
attn_pool, alpha = attn(M)
print('attn_pool:', attn_pool.shape)
print('alpha:', alpha.shape)
Output:
attn_pool: torch.Size([8, 100])
alpha: torch.Size([8, 1, 5])
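As a quick check (my own sketch, assuming the shapes above), the attention weights should sum to 1 along the time axis, and attn_pool should equal the alpha-weighted sum of the memory states:

# sanity check (not in the original): weights sum to 1 over time,
# and the pooled vector is the alpha-weighted sum of M
print(torch.allclose(alpha.sum(dim=-1), torch.ones(batch_size, 1)))  # True
manual_pool = (alpha.permute(2, 0, 1) * M).sum(dim=0)                # [batch_size, D_g]
print(torch.allclose(attn_pool, manual_pool, atol=1e-6))             # True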