一、整体流程
二、feature encoder
理解Conv1d
nn.Conv1d(in_channels=5, out_channels=20, kernel_size=3, stride=2)
假设输入input=(batch, in_channels, in_len)
- batch=1
- in_channels=5,对应向量大小,比如word embedding
- in_len=10,对应word的个数
cnn内部kernel=(in_channels, kernel_size)=(5,3),相当于对(5,10)的输入,使用卷积核(5,3)、步长为2进行卷积,最后一个kernel的输出为(1,(10-3)/2+1)=(1,4),最后所有的输出为(batch, out_channels, out_len)=(1,20,4)。当前CNN层的out_channels作为下一层CNN的in_channels。实际的卷积参数有out_channels*(in_channels/groups)*kernel_size
结构解析
(feature_extractor): ConvFeatureExtractionModel(
(conv_layers): ModuleList(
(0): Sequential(
(0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
(1): Dropout(p=0.0, inplace=False)
(2): Fp32GroupNorm(512, 512, eps=1e-05, affine=True)
(3): GELU()
)
(1): Sequential(
(0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
(1): Dropout(p=0.0, inplace=False)
(2): GELU()
)
(2): Sequential(
(0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
(1): Dropout(p=0.0, inplace=False)
(2): GELU()
)
(3): Sequential(
(0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
(1): Dropout(p=0.0, inplace=False)
(2): GELU()
)
(4): Sequential(
(0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
(1): Dropout(p=0.0, inplace=False)
(2): GELU()
)
(5): Sequential(
(0): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
(1): Dropout(p=0.0, inplace=False)
(2): GELU()
)
(6): Sequential(
(0): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
(1): Dropout(p=0.0, inplace=False)
(2): GELU()
)
)
)
文章使用了7层的CNN,步长为(5,2,2,2,2,2,2),卷积核宽度为(10,3,3,3,3,2,2),假设输入语音的长度为(1,x):
- cnn0 (x-10)/5+1=x/5-1
- cnn1 ((x/5-1)-3)/2+1=x/10-1
- cnn2 x/20-1
- cnn3 x/40-1
- cnn4 x/80-1
- cnn5 x/160
- cnn6 x/320
对于1s的语音长度对应矩阵(1,16000),论文中的channels大小设置的为512,对应的输出为(512,16000/320)=(512,50),可以得到50个512维的向量,相当于每20ms产生一个512维的特征向量。
归一化
在CNN0使用了GroupNorm(对整个sequence的每一个channel做归一化),在CNN6的输出使用torch.nn.LayerNorm(对每个batch内的channel做归一化)
三、context网络
主要包括两部分:
- conv
使用conv1d替换原来的positional embedding,内核大小为128,group为16,输入是(768,100),最后的输出为(768,100) - transformer
结构参考上图
四、mask部分
一帧被选为mask区域起点的概率p是0.065,mask长度M为10(文章的表述方法),对应的训练参数为–mask-length 10 --mask-prob 0.65
- 获取mask区域的个数num_mask
num_mask=语音长度/mask_length*mask_prob,由于存在overlap,所以最终mask的区域会少 - mask lengths有四种计算方式:
- static
- uniform
- normal
- poisson
- 随机选取num_mask的起点,做mask,mask使用的向量
torch.FloatTensor(args.encoder_embed_dim).uniform_()
最终的效果大概有49%的帧会做mask,平均mask span的长度为14.7帧
五、量化模块
product quantization,可以通过这个实例来理解具体原理。
整体流程
文章参数:码本个数G=2,每个码本的条目个数V=320,条目的维度d/G=256/2=128
代码参数:G=latent_groups,V=latent_vars,d=vq_dim
具体的计算流程如上图所示,前向的时候直接找出来最大值对应的码本中的条目,相当于是一个离散的操作,但是这个步骤不可导,无法进行反向传播,为了解决这个问题,采用了gumbel softmax操作
量化原理
使用固定的码本来替换原有的向量,从码本中选取对应的条目有两种方法[1]:gumbel-softmax和k-means clustering。
如果向量只用一个码本来表示向量,容易导致model collapse,即只有部分条目会被使用,为了解决这个问题,引入了group的概念,即使用多个码本的条目拼接起来表示原始的向量。结果显示group越大效果越好
gumbel softmax
gumbel softmax的一个理解:
p
g
,
v
=
e
x
p
(
(
l
g
,
v
+
n
v
)
/
τ
)
∑
k
=
1
V
e
x
p
(
(
l
g
,
k
+
n
k
)
/
τ
)
p_{g,v}=\frac{exp((l_{g,v}+n_v)/\tau)}{\sum_{k=1}^V exp((l_{g,k}+n_k)/\tau)}
pg,v=∑k=1Vexp((lg,k+nk)/τ)exp((lg,v+nv)/τ)
其中:
- l l l为(2,320)维的向量
- n = − l o g ( − l o g ( u ) ) n=-log(-log(u)) n=−log(−log(u)), u u u是0到1的均匀采样
- τ \tau τ控制采样结构的分布,越小越接近one-hot向量,latent-tmp=(2,0.5,0.999995),随update_num增加逐渐减小
个人理解:实际用的时候,前向是根据最大值离散取值,反向传播的时候使用gumbel-softmax进行梯度计算,离散采样的本质是获得一个one-hot的向量,就可以做到离散选取码本的特定条目,所有只需要尽量保证softmax以后的值尽量接近one-hot的形式,就可以做到和离散采样尽量的近似,当 τ \tau τ足够小的时候,gumbel softmax的值会很接近one-hot形式的向量,如下:
a=torch.arange(1,10)
#tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
f=F.gumbel_softmax(a.float(), tau=0.1, hard=False)
#tensor([1.4013e-45, 7.5760e-40, 2.4285e-35, 2.9345e-28, 1.0451e-21, 1.8084e-03, 1.3439e-02, 9.7847e-12, 9.8475e-01])
f=F.gumbel_softmax(a.float(), tau=1, hard=False)
#tensor([3.2670e-03, 2.5322e-04, 4.6802e-03, 8.1748e-03, 7.8111e-03, 1.6325e-01, 2.9331e-02, 2.4782e-01, 5.3542e-01])
个人理解
原来的预测的目标是特征,是一个连续值,预测起来难度会大一些,量化的目的是把预测的目标由连续值修改为离散值,离散值的空间大小为320^2=102400。量化模块首先将连续值的特征映射到离散值,然后将获得的离散值作为预测的目标。
六、loss函数
L
=
L
m
+
α
L
d
+
β
L
f
L=L_m+\alpha L_d+\beta L_f
L=Lm+αLd+βLf
α
\alpha
α设为0.1,
β
\beta
β设为10,对应训练参数–loss-weights ‘[0.1, 10]’
三个loss分别为:Contrastive Loss、Diversity Loss、针对feature encoder的L2 penalty
Contrastive Loss
类似于CPC的loss,主要用来做mask的预测
L
m
=
−
l
o
g
e
x
p
(
s
i
m
(
c
t
,
q
t
)
/
κ
)
∑
q
^
∼
Q
t
e
x
p
(
s
i
m
(
c
t
,
q
^
)
/
κ
)
L_m=-log\frac{exp(sim(c_t,q_t)/\kappa)}{\sum_{\hat{q}\sim Q_t}exp(sim(c_t,\hat{q})/\kappa)}
Lm=−log∑q^∼Qtexp(sim(ct,q^)/κ)exp(sim(ct,qt)/κ)
其中
s
i
m
(
a
,
b
)
=
a
T
b
/
∥
a
∥
∥
b
∥
sim(a,b)=a^Tb/\left\|a\right\|\left\|b\right\|
sim(a,b)=aTb/∥a∥∥b∥
Diversity Loss
目的是为了encourage the model to use the codebook entries equally often.
L
d
=
1
G
V
∑
g
=
1
G
∑
v
=
1
V
p
ˉ
g
,
v
l
o
g
p
ˉ
g
,
v
L_d=\frac{1}{GV}\sum_{g=1}^G\sum_{v=1}^V \bar{p}_{g,v}log\bar{p}_{g,v}
Ld=GV1g=1∑Gv=1∑Vpˉg,vlogpˉg,v
个人理解:就是每个group里面各个条目出现的概率尽量平均,这样上面的目标函数最大。
代码实现
七、参考文献
[1].VQ-WAV2VEC: SELF-SUPERVISED LEARNING OF DISCRETE SPEECH REPRESENTATIONS