Pytorch实现dot/mlp attention

国庆节疯玩了7天,感觉整个人已经成为一条咸鱼了,在上班前一晚把attention机制复习了一下,就当是收收心了(感到羞愧)。

首先实现一些函数,称为attention_utils.py

import numpy as np
import torch
import torch.nn.functional as F


def create_src_lengths_mask(batch_size, src_lengths):
    '''
    生成布尔掩码以防止注意力超出source的末尾
    :param batch_size: int
    :param src_lengths: [batch_size] 每个句子的实际长度
    :return: [batch_size, max_src_len]
    '''
    max_src_len = src_lengths.max()
    # [1, max_src_len]
    src_indices = torch.arange(0, max_src_len).unsqueeze(0).type_as(src_lengths)
    # [batch_size, max_src_len]
    src_indices = src_indices.expand(batch_size, max_src_len)
    src_lengths = src_lengths.unsqueeze(1).expand(batch_size, max_src_len)
    # 小于实际长度的为1,大于的为0,detach截断反向梯度传播
    return (src_indices < src_lengths).int().detach()


def masked_softmax(scores, src_lengths, src_length_masking=True):
    '''
    先生成mask,然后再进行softmax。
    '''
    if src_length_masking:
        batch_size, max_src_len = scores.size()
        # compute masks
        src_mask = create_src_lengths_mask(batch_size, src_lengths)
        # Fill pad positions with -inf
        scores = scores.masked_fill(src_mask == 0, -np.inf)

    # 转换为float16,然后再次转换回来以防止loss爆炸
    return F.softmax(scores.float(), dim=-1).type_as(scores)

然后实现一个无attention的基础类,base_attention.py

import 
  • 2
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值