最近在尝试使用fairseq框架复现多语言机器翻译模型,但是对于其中batch sampler的生成过程还是有一点困惑,fairseq中的batch size与一般的batch size不同,是变长的,即根据指定的max token和句子的长度来确定当前batch中有多少句子,这导致每一个batch中的句子数量是不一致的。因此,本文旨在根据框架代码来探究一下fairseq是如何根据每个语言向的dataset和指定的max token/max sentence来确定每个batch的batch sampler的。
数据集准备
笔者这里采用TED TALKS
数据集作为实验数据集,选择了其中的四个语言向进行实验,分别是ar-en,he-en,it-en,de-en
,参考的论文为When and Why Are Pre-Trained Word Embeddings Useful for Neural Machine Translation?
,数据集地址如下:
https://github.com/neulab/word-embeddings-for-nmt
下载数据集并解压后,显示是各个语言向汇总的数据集,如下所示,
我们需要从中提取出我们需要的语言向,可以参考github仓库中的这个文件:
https://github.com/neulab/word-embeddings-for-nmt/blob/master/ted_reader.py
指定src_lang
和trg_lang
就能够提取出对应的语言对,注意ted data的路径需要指定正确,为了后续学习bpe分词,我们还需要把所有需要的语言向的训练数据进行合并并命名为train.all
文件,可以直接在ted_reader.py
中进行修改,加入save_file_all
方法并进行调用
def save_file_all(self, path_, split_type, data_type):
with open(path_, 'a') as fp:
for line in self.data[split_type][data_type]:
fp.write(line + '\n')
# save to train.all
obj.save_file_all(output_data_path + "/train.all", split_type='train', data_type='source')
obj.save_file_all(output_data_path + "/train.all", split_type='train', data_type='target')
数据集处理
学习bpe并应用
#!/usr/bin/env bash
# echo 'Cloning fairseq repository...'
# git clone git@github.com:facebookresearch/fairseq.git
bpe=bpe
tmp=TED
rm -r $bpe
mkdir -p $bpe
# learn bpe
python -u ../../fairseq/scripts/spm_train.py \
--input=$tmp/train.all \
--model_prefix=spm.bpe \
--vocab_size=70000 \
--character_coverage=1.0 \
--model_type=bpe \
--num_threads=45 \
--shuffle_input_sentence
# apply bpe in {}-en
for split in train valid test; do
# 具体使用时需要根据需要修改语言对
for src in ar he ru ko it ja zh es fr pt nl tr ro pl bg vi de fa hu; do
echo ${split} ${src}-en
python ../../fairseq/scripts/spm_encode.py \
--model spm.bpe.model \
--output_format=piece \
--inputs ${tmp}/${split}.${src}-en.${src} ${tmp}/${split}.${src}-en.en \
--outputs ${bpe}/${split}.${src}-en.bpe.${src} ${bpe}/${split}.${src}-en.bpe.en
done
done
二值化处理
#!/usr/bin/env bash
# create share dict
path=data-bin
rm -r $path
mkdir -p $path
# https://github.com/facebookresearch/fairseq/issues/2110#issue-614837309
cut -f1 spm.bpe.vocab | tail -n +4 | sed "s/$/ 100/g" > $path/dict.txt
#for lang in ar de es fa he it nl pl en; do
# cp $path/dict.txt $path/dict.${lang}.txt
#done
# binarize {}-en
for src in ar he ru ko it ja zh es fr pt nl tr ro pl bg vi de fa hu; do
echo ${src}-en
# 首先需要准备fairseq
fairseq-preprocess \
--source-lang $src --target-lang en \
--trainpref bpe/train.${src}-en.bpe \
--validpref bpe/valid.${src}-en.bpe \
--testpref bpe/test.${src}-en.bpe \
--destdir $path \
--srcdict $path/dict.txt \
--tgtdict $path/dict.txt
done
训练
训练脚本如下所示:
CUDA_VISIBLE_DEVICES=0 fairseq-train /data2/lypan/ted_data/word-embeddings-for-nmt-master/data-bin \
--max-epoch 100 \
--ddp-backend=legacy_ddp \
--task multilingual_translation --lang-pairs ar-en,he-en,it-en,de-en \
--arch multilingual_transformer_iwslt_de_en \
--share-decoders --share-decoder-input-output-embed \
--optimizer adam --adam-betas '(0.9, 0.98)' \
--lr 0.0005 --lr-scheduler inverse_sqrt \
--warmup-updates 4000 --warmup-init-lr '1e-07' \
--label-smoothing 0.1 --criterion label_smoothed_cross_entropy \
--dropout 0.3 --weight-decay 0.0001 \
--save-dir /data2/lypan/fairseq/checkpoints \
--max-tokens 4000 \
--update-freq 8 \
--tensorboard-logdir /data2/lypan/fairseq/logs \
--save-interval 5 --keep-best-checkpoints 1
调试
由于直接在terminal调试太不现实,所以笔者就转战pycharm远程连接服务器进行调试。我们首先关注fairseq_task.py中的以下代码
# create mini-batches with given size constraints
batch_sampler = dataset.batch_by_size(
indices,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
)
此处的indices是一个从0到数据集长度的索引数组,数据集长度指四种语言对中最长的语言对数据集长度(此处是ar-en
语言对数据最长,是214111)
然后跳转进入batch_by_size
方法,可以看到,其实还是没有对batch sampler进行生成,而是又跳转进入了data_utils里的batch_by_size
方法,
return data_utils.batch_by_size(
indices,
num_tokens_fn=self.num_tokens,
num_tokens_vec=num_tokens_vec,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple, fixed_shapes=fixed_shapes,
)
进入data_utils.py后,首先指定了max_tokens
、max_sentences
、bsz_mult
等属性,然后又跳转进入了batch_by_size_fn
方法(不要着急,这是最后一个了)
if fixed_shapes is None:
if num_tokens_vec is None:
return batch_by_size_fn(
indices,
num_tokens_fn,
max_tokens,
max_sentences,
bsz_mult,
)
我们需要先明确一下num_tokens_fn
是传入的哪个函数
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
return max(
dataset.num_tokens(self._map_index(key, index))
for key, dataset in self.datasets.items()
)
其作用是根据传入index来寻找当前四个语言对对应index的最长的一句话,举个栗子,
然后我们再回来看一下batch_by_size_fn
方法,这里为了计算速度,代码是使用cython写的,这也就意味着我们打断点是运行不到这里的,所以我们就只能硬读了,代码如下所示,
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_fn(
np.ndarray[DTYPE_t, ndim=1] indices,
num_tokens_fn,
int64_t max_tokens,
int64_t max_sentences,
int32_t bsz_mult,
):
# indices长度
cdef int32_t indices_len = indices.shape[0]
# 长度与indices相同的全0数组
cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len,
dtype=np.int64)
cdef DTYPE_t[:] indices_view = indices
cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec
cdef int64_t pos
for pos in range(indices_len):
# 记录每个index的各个语言对数据的最大长度
num_tokens_vec[pos] = num_tokens_fn(indices_view[pos])
return batch_by_size_vec(indices, num_tokens_vec, max_tokens,
max_sentences, bsz_mult,)
这里面的num_tokens_vec
数组非常重要,它与indices
数组的长度一致(214111),其中每个位置记录的是各个语言对数据对应index的数据的最大长度,这样假如以后保证最大长度也在max token的限制内,那么其他语言对的数据也肯定在max token的限制内。
我们可以看到,它还调用了batch_by_size_vec
,这个方法就能返回根据数据集句子长度分割好的每个batch的batch sampler了,方法代码如下:
@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_vec(
np.ndarray[int64_t, ndim=1] indices,
np.ndarray[int64_t, ndim=1] num_tokens_vec,
int64_t max_tokens,
int64_t max_sentences,
int32_t bsz_mult,
):
if indices.shape[0] == 0:
return []
assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, (
f"Sentences lengths should not exceed max_tokens={max_tokens}"
)
cdef int32_t indices_len = indices.shape[0]
cdef np.ndarray[int32_t, ndim=1] batches_ends = \
np.zeros(indices_len, dtype=np.int32)
cdef int32_t[:] batches_ends_view = batches_ends
cdef int64_t[:] num_tokens_view = num_tokens_vec
cdef int32_t pos = 0
cdef int32_t new_batch_end = 0
cdef int64_t new_batch_max_tokens = 0
cdef int32_t new_batch_sentences = 0
cdef int64_t new_batch_num_tokens = 0
cdef bool_t overflow = False
cdef bool_t size_matches_with_bsz_mult = False
cdef int32_t batches_count = 0
cdef int32_t batch_start = 0
cdef int64_t tail_max_tokens = 0
cdef int64_t batch_max_tokens = 0
for pos in range(indices_len):
# At every pos we keep stats about the last complete batch [batch_start:batch_end),
# and tail [batch_end:pos].
# 1) Every time when (batch + tail) forms a valid batch
# (according to max_tokens, max_sentences and bsz_mult) we append tail to batch.
# 2) When (batch+tail) violates max_tokens or max_sentences constraints
# we finalize running batch, and tail becomes a new batch.
# 3) There is a corner case when tail also violates constraints.
# In that situation [batch_end:pos-1] (tail without the current pos)
# gets added to the finalized batches, while [pos:pos] becomes a new tail.
#
# Important: For the sake of performance try to avoid using function calls within this loop.
tail_max_tokens = tail_max_tokens \
if tail_max_tokens > num_tokens_view[pos] \
else num_tokens_view[pos]
new_batch_end = pos + 1
new_batch_max_tokens = batch_max_tokens \
if batch_max_tokens > tail_max_tokens \
else tail_max_tokens
# 记录新批次的句子数量
new_batch_sentences = new_batch_end - batch_start
# 记录新批次的token数量,必须根据最长句子的token数量计算(pad)
new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens
overflow = (new_batch_sentences > max_sentences > 0 or
new_batch_num_tokens > max_tokens > 0)
size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or
new_batch_sentences % bsz_mult == 0)
if overflow:
# 尾部序列的长度
tail_num_tokens = tail_max_tokens * \
(new_batch_end - batches_ends_view[batches_count])
# 判断尾部序列有没有超出限制
tail_overflow = tail_num_tokens > max_tokens > 0
# 如果尾部序列超出限制,则尾部序列单独成为一个batch
if tail_overflow:
batches_count += 1
batches_ends_view[batches_count] = pos
tail_max_tokens = num_tokens_view[pos]
# 如果尾部序列没有超出限制,但是该批和尾部序列超出限制,则该批被最终确定,尾部序列成为一个新批
batch_start = batches_ends_view[batches_count]
batches_count += 1
new_batch_max_tokens = tail_max_tokens
if overflow or size_matches_with_bsz_mult:
batches_ends_view[batches_count] = new_batch_end
batch_max_tokens = new_batch_max_tokens
tail_max_tokens = 0
if batches_ends_view[batches_count] != indices_len:
batches_count += 1
# Memory and time-efficient split
return np.split(indices, batches_ends[:batches_count])
该函数用于根据一些约束条件,如最大标记数、最大句子数和批量大小的倍数,将给定的索引列表分成若干批。
该函数的逻辑如下。
- 该函数遍历索引列表中的每个位置。
- 对于每个位置,它都会跟踪最后一个完整的批次和剩余的尾巴的统计信息。
- 如果批次和尾巴的总和根据约束条件形成一个有效的批次,则尾巴被附加到批次上。
- 如果该批和尾巴违反了约束条件,则运行中的批被最终确定,而尾巴则成为一个新批。
- 如果尾巴本身违反了约束条件,该函数最终确定两个批次,并将尾巴设置为新的批次。
- 如果批次和尾巴满足基于批次大小倍数的约束条件,该函数将当前位置添加到当前批次,并更新当前批次和尾巴的统计信息。
- 最后,该函数根据批处理结束返回从索引列表中分离出来的批处理列表。
以上,就完成了batch sampler的生成过程,结果如下,
当然,后续可能会有shuffle的过程,但每个batch中数据的顺序就不会发生变化了,变化的是各个batch的顺序。