fairseq读代码系列(二)——多语言的采样过程

最近在尝试使用fairseq框架复现多语言机器翻译模型,但是对于其中batch sampler的生成过程还是有一点困惑,fairseq中的batch size与一般的batch size不同,是变长的,即根据指定的max token和句子的长度来确定当前batch中有多少句子,这导致每一个batch中的句子数量是不一致的。因此,本文旨在根据框架代码来探究一下fairseq是如何根据每个语言向的dataset和指定的max token/max sentence来确定每个batch的batch sampler的

数据集准备

笔者这里采用TED TALKS数据集作为实验数据集,选择了其中的四个语言向进行实验,分别是ar-en,he-en,it-en,de-en,参考的论文为When and Why Are Pre-Trained Word Embeddings Useful for Neural Machine Translation?,数据集地址如下:

https://github.com/neulab/word-embeddings-for-nmt

下载数据集并解压后,显示是各个语言向汇总的数据集,如下所示,
在这里插入图片描述
我们需要从中提取出我们需要的语言向,可以参考github仓库中的这个文件:

https://github.com/neulab/word-embeddings-for-nmt/blob/master/ted_reader.py

指定src_langtrg_lang就能够提取出对应的语言对,注意ted data的路径需要指定正确,为了后续学习bpe分词,我们还需要把所有需要的语言向的训练数据进行合并并命名为train.all文件,可以直接在ted_reader.py中进行修改,加入save_file_all方法并进行调用

def save_file_all(self, path_, split_type, data_type):
    with open(path_, 'a') as fp:
        for line in self.data[split_type][data_type]:
            fp.write(line + '\n')

# save to train.all
obj.save_file_all(output_data_path + "/train.all", split_type='train', data_type='source')
obj.save_file_all(output_data_path + "/train.all", split_type='train', data_type='target')

数据集处理

学习bpe并应用

#!/usr/bin/env bash
# echo 'Cloning fairseq repository...'
# git clone git@github.com:facebookresearch/fairseq.git
bpe=bpe
tmp=TED
rm -r $bpe
mkdir -p $bpe

# learn bpe
python -u ../../fairseq/scripts/spm_train.py \
  --input=$tmp/train.all \
  --model_prefix=spm.bpe \
  --vocab_size=70000 \
  --character_coverage=1.0 \
  --model_type=bpe \
  --num_threads=45 \
  --shuffle_input_sentence

# apply bpe in {}-en
for split in train valid test; do
  # 具体使用时需要根据需要修改语言对
  for src in ar he ru ko it ja zh es fr pt nl tr ro pl bg vi de fa hu; do
    echo ${split} ${src}-en
    python ../../fairseq/scripts/spm_encode.py \
      --model spm.bpe.model \
      --output_format=piece \
      --inputs ${tmp}/${split}.${src}-en.${src} ${tmp}/${split}.${src}-en.en \
      --outputs ${bpe}/${split}.${src}-en.bpe.${src} ${bpe}/${split}.${src}-en.bpe.en
  done
done

二值化处理

#!/usr/bin/env bash
# create share dict
path=data-bin
rm -r $path
mkdir -p $path

# https://github.com/facebookresearch/fairseq/issues/2110#issue-614837309
cut -f1 spm.bpe.vocab | tail -n +4 | sed "s/$/ 100/g" > $path/dict.txt
#for lang in ar de es fa he it nl pl en; do
#  cp $path/dict.txt $path/dict.${lang}.txt
#done

# binarize {}-en
for src in ar he ru ko it ja zh es fr pt nl tr ro pl bg vi de fa hu; do
  echo ${src}-en
  # 首先需要准备fairseq
  fairseq-preprocess \
      --source-lang $src --target-lang en \
      --trainpref bpe/train.${src}-en.bpe \
      --validpref bpe/valid.${src}-en.bpe \
      --testpref bpe/test.${src}-en.bpe \
      --destdir $path \
      --srcdict $path/dict.txt \
      --tgtdict $path/dict.txt
done

训练

训练脚本如下所示:

CUDA_VISIBLE_DEVICES=0 fairseq-train /data2/lypan/ted_data/word-embeddings-for-nmt-master/data-bin \
    --max-epoch 100 \
    --ddp-backend=legacy_ddp \
    --task multilingual_translation --lang-pairs ar-en,he-en,it-en,de-en \
    --arch multilingual_transformer_iwslt_de_en \
    --share-decoders --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' \
    --lr 0.0005 --lr-scheduler inverse_sqrt \
    --warmup-updates 4000 --warmup-init-lr '1e-07' \
    --label-smoothing 0.1 --criterion label_smoothed_cross_entropy \
    --dropout 0.3 --weight-decay 0.0001 \
    --save-dir /data2/lypan/fairseq/checkpoints \
    --max-tokens 4000 \
    --update-freq 8 \
    --tensorboard-logdir /data2/lypan/fairseq/logs \
    --save-interval 5 --keep-best-checkpoints 1

调试

由于直接在terminal调试太不现实,所以笔者就转战pycharm远程连接服务器进行调试。我们首先关注fairseq_task.py中的以下代码

# create mini-batches with given size constraints
batch_sampler = dataset.batch_by_size(
    indices,
    max_tokens=max_tokens,
    max_sentences=max_sentences,
    required_batch_size_multiple=required_batch_size_multiple,
)

此处的indices是一个从0到数据集长度的索引数组,数据集长度指四种语言对中最长的语言对数据集长度(此处是ar-en语言对数据最长,是214111)
在这里插入图片描述
在这里插入图片描述
然后跳转进入batch_by_size方法,可以看到,其实还是没有对batch sampler进行生成,而是又跳转进入了data_utils里的batch_by_size方法,

return data_utils.batch_by_size(
    indices,
    num_tokens_fn=self.num_tokens,
    num_tokens_vec=num_tokens_vec,
    max_tokens=max_tokens,
    max_sentences=max_sentences,
    required_batch_size_multiple=required_batch_size_multiple, fixed_shapes=fixed_shapes,
)

进入data_utils.py后,首先指定了max_tokensmax_sentencesbsz_mult等属性,然后又跳转进入了batch_by_size_fn方法(不要着急,这是最后一个了)

if fixed_shapes is None:
    if num_tokens_vec is None:
        return batch_by_size_fn(
           indices,
           num_tokens_fn,
           max_tokens,
           max_sentences,
           bsz_mult,
        )

我们需要先明确一下num_tokens_fn是传入的哪个函数

    def num_tokens(self, index):
        """Return an example's length (number of tokens), used for batching."""
        return max(
            dataset.num_tokens(self._map_index(key, index))
            for key, dataset in self.datasets.items()
        )

其作用是根据传入index来寻找当前四个语言对对应index的最长的一句话,举个栗子,

在这里插入图片描述

四个语言对中的第0个句子(根据长度排序后的,根据长度排序可以方便padding,一个batch中的数据只要pad到一样的长度就可以,这样可以减少pad的长度),其中最长的一句话的长度为5,所以返回结果为5;四个语言对中的第214110个句子(只有ar-en语言对有对应数据),其中最长的一句话的长度为464,所以返回结果为464。

然后我们再回来看一下batch_by_size_fn方法,这里为了计算速度,代码是使用cython写的,这也就意味着我们打断点是运行不到这里的,所以我们就只能硬读了,代码如下所示,

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_fn(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
	# indices长度
    cdef int32_t indices_len = indices.shape[0]
    # 长度与indices相同的全0数组
    cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len,
                                                               dtype=np.int64)
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec
    cdef int64_t pos
    for pos in range(indices_len):
    	# 记录每个index的各个语言对数据的最大长度
        num_tokens_vec[pos] = num_tokens_fn(indices_view[pos])
    return batch_by_size_vec(indices, num_tokens_vec, max_tokens,
        max_sentences, bsz_mult,)

这里面的num_tokens_vec数组非常重要,它与indices数组的长度一致(214111),其中每个位置记录的是各个语言对数据对应index的数据的最大长度,这样假如以后保证最大长度也在max token的限制内,那么其他语言对的数据也肯定在max token的限制内

我们可以看到,它还调用了batch_by_size_vec,这个方法就能返回根据数据集句子长度分割好的每个batch的batch sampler了,方法代码如下:

@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_vec(
    np.ndarray[int64_t, ndim=1] indices,
    np.ndarray[int64_t, ndim=1] num_tokens_vec,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
    if indices.shape[0] == 0:
        return []

    assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, (
        f"Sentences lengths should not exceed max_tokens={max_tokens}"
    )

    cdef int32_t indices_len = indices.shape[0]
    cdef np.ndarray[int32_t, ndim=1] batches_ends = \
            np.zeros(indices_len, dtype=np.int32)
    cdef int32_t[:] batches_ends_view = batches_ends
    cdef int64_t[:] num_tokens_view = num_tokens_vec

    cdef int32_t pos = 0
    cdef int32_t new_batch_end = 0

    cdef int64_t new_batch_max_tokens = 0
    cdef int32_t new_batch_sentences = 0
    cdef int64_t new_batch_num_tokens = 0

    cdef bool_t overflow = False
    cdef bool_t size_matches_with_bsz_mult = False

    cdef int32_t batches_count = 0
    cdef int32_t batch_start = 0
    cdef int64_t tail_max_tokens = 0
    cdef int64_t batch_max_tokens = 0

    for pos in range(indices_len):
        # At every pos we keep stats about the last complete batch [batch_start:batch_end),
        #      and tail [batch_end:pos].
        # 1) Every time when (batch + tail) forms a valid batch
        #      (according to max_tokens, max_sentences and bsz_mult) we append tail to batch.
        # 2) When (batch+tail) violates max_tokens or max_sentences constraints
        #      we finalize running batch, and tail becomes a new batch.
        # 3) There is a corner case when tail also violates constraints.
        #      In that situation [batch_end:pos-1] (tail without the current pos)
        #      gets added to the finalized batches, while [pos:pos] becomes a new tail.
        #
        # Important: For the sake of performance try to avoid using function calls within this loop.

        tail_max_tokens = tail_max_tokens \
                            if tail_max_tokens > num_tokens_view[pos] \
                            else num_tokens_view[pos]
        new_batch_end = pos + 1
        new_batch_max_tokens = batch_max_tokens \
                                if batch_max_tokens > tail_max_tokens \
                                else tail_max_tokens
        # 记录新批次的句子数量
        new_batch_sentences = new_batch_end - batch_start
        # 记录新批次的token数量,必须根据最长句子的token数量计算(pad)
        new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens

        overflow = (new_batch_sentences > max_sentences > 0 or
                    new_batch_num_tokens > max_tokens > 0)
        size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or
                                      new_batch_sentences % bsz_mult == 0)

        if overflow:
        	# 尾部序列的长度
            tail_num_tokens = tail_max_tokens * \
                    (new_batch_end - batches_ends_view[batches_count])
            # 判断尾部序列有没有超出限制
            tail_overflow = tail_num_tokens > max_tokens > 0
            # 如果尾部序列超出限制,则尾部序列单独成为一个batch
            if tail_overflow:
                batches_count += 1
                batches_ends_view[batches_count] = pos
                tail_max_tokens = num_tokens_view[pos]
            # 如果尾部序列没有超出限制,但是该批和尾部序列超出限制,则该批被最终确定,尾部序列成为一个新批
            batch_start = batches_ends_view[batches_count]
            batches_count += 1
            new_batch_max_tokens = tail_max_tokens

        if overflow or size_matches_with_bsz_mult:
            batches_ends_view[batches_count] = new_batch_end
            batch_max_tokens = new_batch_max_tokens
            tail_max_tokens = 0
    if batches_ends_view[batches_count] != indices_len:
        batches_count += 1
    # Memory and time-efficient split
    return np.split(indices, batches_ends[:batches_count])

该函数用于根据一些约束条件,如最大标记数、最大句子数和批量大小的倍数,将给定的索引列表分成若干批。
该函数的逻辑如下。

  • 该函数遍历索引列表中的每个位置。
  • 对于每个位置,它都会跟踪最后一个完整的批次和剩余的尾巴的统计信息。
  • 如果批次和尾巴的总和根据约束条件形成一个有效的批次,则尾巴被附加到批次上。
  • 如果该批和尾巴违反了约束条件,则运行中的批被最终确定,而尾巴则成为一个新批。
  • 如果尾巴本身违反了约束条件,该函数最终确定两个批次,并将尾巴设置为新的批次。
  • 如果批次和尾巴满足基于批次大小倍数的约束条件,该函数将当前位置添加到当前批次,并更新当前批次和尾巴的统计信息。
  • 最后,该函数根据批处理结束返回从索引列表中分离出来的批处理列表。

以上,就完成了batch sampler的生成过程,结果如下,
在这里插入图片描述
当然,后续可能会有shuffle的过程,但每个batch中数据的顺序就不会发生变化了,变化的是各个batch的顺序。

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
FPGA与高速ADC接口是数字信号处理系统中非常重要的一部分,它们可以帮助我们实现高速数据采集、数字信号处理、通信等功能。其中,ADC是数据采集的核心,采样率的高低决定了系统的性能,因此如何有效地与高速ADC接口是一个非常关键的问题。 本文将介绍一种基于Xilinx FPGA与高速ADC接口的实现方案,以250MSPS采样率的ADC9481为例。具体实现过程如下: 1.硬件连接 首先需要将ADC9481与FPGA进行连接。ADC9481有两组LVDS输出,每组包含14位数据和1位时钟。因此需要使用一对LVDS差分信号来传输一组数据和时钟信号,共需要8对LVDS差分信号。 2.时钟配置 ADC的采样率由时钟信号控制,因此需要配置FPGA的时钟使其与ADC时钟同步。ADC9481的时钟频率最高可达500MHz,一般使用LVPECL时钟驱动器来提供时钟信号。 在FPGA端,需要将时钟信号通过BUFG(全局缓存)引脚输入到FPGA的时钟管理单元(MMCM)中,使用MMCM生成与ADC时钟同步的本地时钟信号。 3.数据接收 ADC的数据输出是14位的差分数据,需要通过FPGA的差分输入接口进行接收。在FPGA端,可以使用选择器和寄存器来对数据进行处理和存储。选择器可以选择要写入哪个寄存器,而寄存器则用于存储ADC的采样数据。在这个过程中,需要注意选择器和寄存器的延迟时间,确保数据正常存储。 4.数据处理 ADC采样数据的处理包括去偏置、解码、滤波等操作。其中,去偏置是为了消除ADC的直流偏置,解码是将ADC的输出数据转换成相应的数字量,滤波则是为了去除高频噪声。 在FPGA端,可以使用DSP48E1模块进行数据处理。DSP48E1模块是Xilinx FPGA中的专用数字信号处理模块,它可以进行加、减、乘、除、滤波等操作。在这个过程中,需要注意DSP48E1模块的使用方法和相关参数的设置。 5.数据存储 最后,需要将处理后的数据存储到内存中。在FPGA端,可以使用Block RAM(BRAM)或FIFO等存储器来存储数据。其中,BRAM是单端口存储器,适用于小型数据存储;而FIFO是双端口存储器,适用于大型数据存储。 在存储数据时,需要注意存储器的写时序和容量,确保数据能够正常存储。 通过以上实现步骤,就可以与高速ADC接口进行有效的数据采集和处理。当然,具体实现过程可能因硬件设备和应用场景而有所不同,需要根据实际情况进行调整和优化。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值