小黑维度逐行分析与调试:SimpleAttention、MaskedNLLLoss

这篇博客介绍了两个深度学习中的关键概念:MaskedNLLLoss和SimpleAttention。MaskedNLLLoss是一种处理序列数据的损失函数,通过考虑掩码来避免无效位置的影响。SimpleAttention则展示了如何实现一个简单的注意力机制,用于加权聚合输入序列信息,其输出为加权后的池化向量和注意力权重分布。这两个组件常用于序列建模任务,如机器翻译和语音识别。
摘要由CSDN通过智能技术生成

1.MaskedNLLLoss

import torch
import torch.nn as nn
class MaskedNLLLoss(nn.Module):
    def __init__(self,weight = None):
        super(MaskedNLLLoss,self).__init__()
        self.weight = weight
        self.loss = nn.NLLLoss(weight = weight,reduction = 'sum')
    def forward(self,pred,target,mask):
        mask_ = mask.view(-1,1)    # [batch_size * seq_len,1]
        if type(self.weight) == type(None):
            loss = self.loss(pred * mask_,target) / torch.sum(mask)
        else:
            loss = self.loss(pred * mask_,target) / torch.sum(self.weight[target] * mask_.squeeze())
        return loss
predict = torch.randn([5,7,10])
target = torch.ones([5,7]).long()
mask = torch.ones([5,7]).long()
loss = MaskedNLLLoss()
loss(predict.view(-1,10),target.view(-1),mask.view(-1))

输出:

tensor(-0.0985)

2.SimpleAttention

class SimpleAttention(nn.Module):
    def __init__(self,input_dim):
        super(SimpleAttention,self).__init__()
        self.input_dim = input_dim
        self.scalar = nn.Linear(self.input_dim,1,bias = False)
    def forward(self,M,x = None):
        # M:[t-1,batch_size,D_g]
        # x:[batch_size,D_m]
        scale = self.scalar(M)   # [t-1,batch_size,1]
        alpha = F.softmax(scale,dim = 0).permute(1,2,0)    # [batch_size,1,t-1]
        attn_pool = torch.bmm(alpha,M.transpose(0,1))[:,0,:]    # [batch_size,D_g]
        return attn_pool,alpha
D_g = 100
attn = SimpleAttention(D_g)
t_1 = 5
batch_size = 8
M = torch.randn([t_1,batch_size,D_g])
attn_pool,alpha = attn(M)
print('attn_pool:',attn_pool.shape)
print('alpha:',alpha.shape)

输出:

attn_pool: torch.Size([8, 100])
alpha: torch.Size([8, 1, 5])

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值