MultiHeadAttension源码解析——batch_first参数含义

前言

简单介绍batch_first参数的含义和相关概念。

1. 问题描述

Pytorch的多头注意力(MultiHeadAttension)代码中,有一个batch_first参数,在传递参数的时候必须注意。

def forward(self, query: Tensor, key: Tensor, value: Tensor, 
                key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, 
                attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
      
    if self.batch_first:
        query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

官方文档对batch_first的解释如下。

batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).

简单翻译一下:如果设置为True,输入和输出张量按照(batch,seq,feature)的顺序提供。默认值是False。按照(seq,batch,feature)的顺序。

通过查看源码发现,如果按照(batch,seq,feature)的顺序传入参数,并且batch_first设置为True,那么会自动转换成(seq,batch,feature)的顺序,输出结果的时候,再转换回来。

if self.batch_first:
    query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

转换代码

# 参数1和0意味着交换第一和第二个索引,也就是batch和seq。
transpose(1, 0)

2. 相关概念

batch、seq和feature

什么是batch:批量大小,就是一次传入的序列(句子)的数量。

什么是seq:序列长度,即单词数量。

什么是feature:特征长度,每个单词向量(Embedding)的长度。

也记为N(批量)、T(序列)、C(特征)。

为什么默认不是批量在前呢?

根据这篇文章https://zhuanlan.zhihu.com/p/32103001的解释:

为了便于并行计算,cuDNN中的RNN模型提供的API就是batch_size在第二维度。

虽然上文是按照RNN来解释的,但是应该对注意力模型也适用。

至于cuDNN这样排序的原因,是因为batch first=True意味着模型的输入(一个Tensor)在内存中存储时,先存储第一个sequence,再存储第二个,而如果是seq放在前面,模型的输入在内存中,先存储所有序列的第一个单元,然后是第二个单元。

  • 6
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值