前言
简单介绍batch_first
参数的含义和相关概念。
1. 问题描述
Pytorch的多头注意力(MultiHeadAttension)代码中,有一个batch_first
参数,在传递参数的时候必须注意。
def forward(self, query: Tensor, key: Tensor, value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
if self.batch_first:
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
略
官方文档对batch_first的解释如下。
batch_first – If
True
, then the input and output tensors are provided as (batch, seq, feature). Default:False
(seq, batch, feature).
简单翻译一下:如果设置为True,输入和输出张量按照(batch,seq,feature)的顺序提供。默认值是False。按照(seq,batch,feature)的顺序。
通过查看源码发现,如果按照(batch,seq,feature)的顺序传入参数,并且batch_first设置为True,那么会自动转换成(seq,batch,feature)的顺序,输出结果的时候,再转换回来。
if self.batch_first:
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
转换代码
# 参数1和0意味着交换第一和第二个索引,也就是batch和seq。
transpose(1, 0)
2. 相关概念
batch、seq和feature
什么是batch:批量大小,就是一次传入的序列(句子)的数量。
什么是seq:序列长度,即单词数量。
什么是feature:特征长度,每个单词向量(Embedding)的长度。
也记为N(批量)、T(序列)、C(特征)。
为什么默认不是批量在前呢?
根据这篇文章https://zhuanlan.zhihu.com/p/32103001的解释:
为了便于并行计算,cuDNN
中的RNN
模型提供的API就是batch_size
在第二维度。
虽然上文是按照RNN来解释的,但是应该对注意力模型也适用。
至于cuDNN这样排序的原因,是因为batch first=True意味着模型的输入(一个Tensor)在内存中存储时,先存储第一个sequence,再存储第二个,而如果是seq放在前面,模型的输入在内存中,先存储所有序列的第一个单元,然后是第二个单元。