动态batch和静态batch的原理和代码详解

原理:

在wenet中支持两种batch的方式,

第一种是常规的batch方案,但该方案当语音长短差异过大时,显存利用率低,同时带来显存oom的风险;

第二种是动态batch,不指定batch_size的大小,只限制了batch中的最大总帧数,这样就能够充分利用显存,同时不会有内存oom的风险。

论文:

https://arxiv.org/pdf/2102.01547.pdf

代码位置:

wenet/processor.py at main · wenet-e2e/wenet · GitHub

代码实现

dynamic_batch:

参数:

data: Iterable[{key,feat,label}]

max_frames_in_batch: 整个batch中能包含的总帧数不大于该值

返回:

Iterable[{key,feat,label}]

代码思路:

  步骤一:遍历data,获取每个样本sample

  步骤二:获取该样本的帧数,并更新最大帧数,然后获取padding后的总帧数

for sample in data:
    assert 'feat' in sample
    assert isinstance(sample['feat'], torch.Tensor)
    new_sample_frames = sample['feat'].size(0)
    longest_frames = max(longest_frames, new_sample_frames)
    frames_after_padding = longest_frames * (len(buf) + 1)

  步骤三:若大于batch中的最大帧,则将buf添加到迭代器中,否则一个buf没满,等待下次

if frames_after_padding > max_frames_in_batch:
    yield buf
    buf = [sample]
    longest_frames = new_sample_frames
else:
    buf.append(sample)

  步骤四:遍历结束后将剩余的buf也添加到迭代器中

static_batch:

参数:

  data:Iterable[{key,feat,label}]

  batch_size: batch size

返回:

Iterable[list{key,feat,label}]

代码思路:

  步骤一:遍历data获取每个样本sample

  步骤二:将sample取出来,当buf添加了batch_size数目样本之后,将buf传给迭代器,然后清空buf

buf = []
for sample in data:
    buf.append(sample)
    if len(buf) >= batch_size:
        yield buf
        buf = []

  步骤三:遍历结束后将剩余的buf也添加到迭代器中

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值