BlockShuffle-一个使模型训练速度提升20%的Trick

BlockShuffle,就是在训练过程中使用分块打乱替代随机打乱的一种方法,即将原始数据按照数据长度进行排序,然后进行batch划分,在对batch训练进行打乱。这样操作,可以减少数据padding长度,缩短训练时长

注意:该方法适用的前提是数据输入为变长。(不适合将所有数据padding到模型最大长度的代码)

举例说明

这里简单举个例子,大家就可以理解BlockShuffle为什么可以提高训练速度。

假如,数据长度为[1,2,3,4,512,512,512,512],训练时batch_size大小为2。

当采用随机打乱进行模型训练时,有一种可能是将数据分成[[1,512],[2,512],[3,512],[4,512]],模型训练需要batch内的数据等长,因此padding过后的数据长度为[[512,512],[512,512],[512,512],[512,512]]

当采用分块打乱进行模型训练时,会先对数据进行排序,再按照batch_size进行切割,数据为[[1,2],[3,4],[512,512],[512,512]],进行padding过后的数据长度为[[2,2],[4,4],[512,512],[512,512]]

由于序列长度越长,训练时间越久,因此分块打乱的训练时长要比随机打乱的时长短。因此,理论上,当数据长度方差越大时,分块打乱越省时。

代码实现

Pytorch的实现代码,如下(结合DataSet和DataLoader):

from torch.utils.data.dataloader import _SingleProcessDataLoaderIter, _MultiProcessingDataLoaderIter
import random
from torch.utils.data import Dataset, DataLoader
from itertools import chain


class BlockShuffleDataLoader(DataLoader):
    def __init__(self, dataset: Dataset, sort_key, sort_bs_num=None, is_shuffle=True, **kwargs):
        """
        初始化函数,继承DataLoader类
        Args:
            dataset: Dataset类的实例,其中中必须包含dataset变量,并且该变量为一个list
            sort_key: 排序函数,即使用dataset元素中哪一个变量的长度进行排序
            sort_bs_num: 排序范围,即在多少个batch_size大小内进行排序,默认为None,表示对整个序列排序
            is_shuffle: 是否对分块后的内容,进行随机打乱,默认为True
            **kwargs:
        """
        assert isinstance(dataset.data_set, list), "dataset为Dataset类的实例,其中中必须包含dataset变量,并且该变量为一个list"
        super().__init__(dataset, **kwargs)
        self.sort_bs_num = sort_bs_num
        self.sort_key = sort_key
        self.is_shuffle = is_shuffle

    def __iter__(self):
        self.dataset.data_set = self.block_shuffle(self.dataset.data_set, self.batch_size, self.sort_bs_num,
                                                   self.sort_key, self.is_shuffle)
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

    @staticmethod
    def block_shuffle(data, batch_size, sort_bs_num, sort_key, is_shuffle):
        # 将数据按照batch_size大小进行切分
        tail_data = [] if len(data) % batch_size == 0 else data[-len(data) % batch_size:]
        data = data[:len(data) - len(tail_data)]
        assert len(data) % batch_size == 0
        # 获取真实排序范围
        sort_bs_num = len(data) // batch_size if sort_bs_num is None else sort_bs_num
        # 按照排序范围进行数据划分
        data = [data[i:i + sort_bs_num * batch_size] for i in range(0, len(data), sort_bs_num * batch_size)]
        # 在排序范围,根据排序函数进行降序排列
        data = [sorted(i, key=sort_key, reverse=True) for i in data]
        # 将数据根据batch_size获取batch_data
        data = list(chain(*data))
        data = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
        # 判断是否需要对batch_data序列进行打乱
        if is_shuffle:
            random.shuffle(data)
        # 将tail_data填补回去
        data = list(chain(*data)) + tail_data
        return data

本代码,主要继承DataLoader类,并要求输入的DataSet类必须包含data_set成员变量,data_set存放的是所有数据,类型为list

sort_key为排序函数,即使用data_set中元素的哪一个变量的长度进行排序,例如:

sort_key=lambda x: len(x["input_ids"])

即,对元素中input_ids变量进行排序。(这里我们每一个元素为一个dict,如果为list,请自行修改)

(PaddlePaddle2.0的实现代码,大家可以按需修改)

实验结果

本人写了一个基于BERT的情绪识别的代码,进行了速度和效果测试,代码见:

https://github.com/liucongg/BlockShuffleTest

训练参数如下:训练数据大小为27768,采用BERT-Base模型,Batch_Size为32,模型最大长度256。

速度提高了多少呢?

当训练轮数为2,采用原始随机打乱,所耗费时长为590秒(运行train.py的train_ori_time函数);采用分块打乱,所耗费时长为458秒(运行train.py的train_block_shuffle_time函数)。

速度提升了(590-458)/ 590* 100% = 22.37%

速度提升了,那么效果是否会下降呢?

在这份代码上,训练了5个epoch,训练参数如上,采用原始随机打乱时,dev上的acc为0.7785,采用分块打乱时,dev上的acc为0.7849。效果没有下降,反而提高了一丢丢(应该是随机性导致的)。

运行train.py的train函数,同时修改is_block_shuffle配置。

注意:有一种特殊情况,会导致BlockShuffle不能收敛。假如有0和1两种标签,而恰恰长度短的数据标签全为0,长度长的标签全为1,导致所有batch序列中每个batch仅有一种标签,使模型无法收敛。个人觉得这种小概率事件,基本不会发生,如果发生了,就是命不好。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值