本文主要介绍wenet的动态chunk设计技巧
在wenet的各类介绍中很少有人主动单独介绍wenet的chunk的实现原理,接下来主要谈谈以下几点:
一:wenet是如何通过chunk的设计来实现流式和非流式一起训练;
二:动态chunk是如何设计的;
三:动态左chunk的设计原理;
非流式解码很好理解,讲chunk_size设置为batch中的max_frame即可。而为了实现非流式解码,wenet在encoder中使用了基于chunk的attention,并允许各个batch使用不同的chunk大小,该操作主要通过attention mask来实现。
代码位置:
https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
代码思路:
chunk_size = torch.randint(1, max_len, (1, )).item()
num_left_chunks = -1
if chunk_size > max_len // 2:
chunk_size = max_len
else:
chunk_size = chunk_size % 25 + 1
if use_dynamic_left_chunk:
max_left_chunks = (max_len - 1) // chunk_size
num_left_chunks = torch.randint(0, max_left_chunks,(1, )).item()
针对问题一:wenet是如何通过chunk的设计来实现流式和非流式一起训练?
方法:从下面代码中我们可以看出有一半的训练是用流式进行训练,而剩下的一半数据是用非流式训练。
chunk_size = torch.randint(1, max_len, (1, )).item()
num_left_chunks = -1
if chunk_size > max_len // 2:
chunk_size = max_len
针对问题二:动态chunk是如何设计的?
chunk_size取了[1-25]中随机整数值
chunk_size = chunk_size % 25 + 1
针对问题三:动态左chunk的设计原理
目的:保证了模型额能够在解码时不会因为chunk设置的具体值所带来效果大幅变差的情况。
方法:根据参数调整所关注的左侧chunk的数量,如果配置该use_dynamic_left_chunk参数为true的话,则先根据chunk_size算出左侧还有多少个chunk,然后随机设置左侧的chunk数;
if use_dynamic_left_chunk:
max_left_chunks = (max_len - 1) // chunk_size
num_left_chunks = torch.randint(0, max_left_chunks,(1, )).item()