摘要:
Masked language model(MLM)预训练任务(如bert)通过使用[MASK]标记随机替换token来破坏输入(corrupted input)。虽然这种做法在下游任务可以得到较好的结果。但是只在大规模计算的前提下是有效的。本文提出了一种更加高效的预训练任务称为Replaced token detection。它不是使用[mask]随机替换,而是使用一个生成器以更加合理的方式去随机替换token来破坏输入。之后,不是去预测被破坏输入的原始标记,而是训练一个判别器,去预测被破坏输入的每一个token是否被生成器替换过。
在相同的模型大小,数据量和计算条件下,Electra学习的上下文表示大大优于bert模型。
Tips:Bert以15%的概率随机替换token去预测。相当于只针对15%的输入token进行了学习,而electra对序列的每一个token都进行了学习。学习的效率更高。(例如有两个同学,A同学比较懒散,每次只随机学习15%的知识。学习效率是15%。B同学很认真,每次学习全部内容。那么A同学要学习很多次将知识学习完。在相同的学习次数下,B同学会学习到更多的知识)。
简介
目前主流的文本表示学习方法可以看做是学习降噪自动编码器。他们扰乱输入后,让模型去恢复原始输入。但是需要更大的计算成本。一般都以15%的几率扰乱输入。
并且ELECTRA还解决了BERT的不匹配问题(MLM任务输入带有mask标记,而下游任务中没有)。MLM任务在ELECTRA只是作为生成器来得到破坏的输入,我们使用判断器学习文本的上下文表示。
Tips:
1)提升学习效率。通常学习文本表示时,学习的时间越久,文本表示的效果越好。之前都是通过扩大样本集和增加训练时间来提升文本表示的效果。ELECTRA是通过提升学习效率来提升文本表示效果。BERT只对输入token的15%进行学习。而ELECTRA对每个输入token对进行学习。
2)解决bert中的不匹配问题(预训练中输入有MASK标记,而下游任务中没有)。
模型结构:
生成器就是一个MLM任务的模型。将带[MASK]标签的input输入生成器中。生成器对[MASK]token的预测值替换[MASK]得到生成器的输入。生成器逐帧判断该输入是否与原始序列相等(二分类问题)。如果一个token被[MASK]掉,但是又被生成器预测对。那么在判别器的训练中认为这个token没有被替换过。就是原始的。
用数学语言描述就是:X是原始输入
X
=
(
x
1
,
x
2
,
⋯
,
x
n
)
X=(x_1,x_2,\cdots,x_n)
X=(x1,x2,⋯,xn),h是文本的上下文表示特征
h
=
(
h
1
,
h
2
,
⋯
,
h
n
)
h=(h_1,h_2,\cdots,h_n)
h=(h1,h2,⋯,hn)(transformer提取的序列特征)。用m来指示哪些token被MASK,被mask的token索引的集合。
x
m
a
s
k
e
d
=
R
E
P
L
A
C
E
(
x
,
m
,
[
M
A
S
K
]
)
x^{masked}=REPLACE(x,m,[MASK])
xmasked=REPLACE(x,m,[MASK])
对于每个被MASK的token
x
t
x_t
xt,MLM生成器认为该位置属于该[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-j2UTMM1o-1621245553511)(file:///C:/Users/SFANG~1.FAN/AppData/Local/Temp/msohtmlclip1/01/clip_image012.png)]的概率是
P
G
(
c
∣
x
)
=
e
x
p
(
e
m
b
e
d
(
x
t
)
T
h
G
(
x
t
)
)
/
∑
x
′
e
x
p
(
e
m
b
e
d
(
x
′
)
T
h
G
(
x
t
)
)
P_G(c|x)=\mathop{exp}(embed(x_t)^Th_G(x_t))/\sum_{x'}exp(embed(x')^Th_G(x_t))
PG(c∣x)=exp(embed(xt)ThG(xt))/x′∑exp(embed(x′)ThG(xt))
对于每个
i
∈
m
i\in m
i∈m生成器预测的结果记做
x
^
i
\hat x_i
x^i,用
x
^
i
\hat x_i
x^i替换
x
m
a
s
k
e
d
x^{masked}
xmasked中的
[
M
A
S
K
]
[MASK]
[MASK]标记。得到降噪自动编码机的扰乱输入
x
c
o
r
r
u
p
t
x^{corrupt}
xcorrupt,最后用判别器判别
x
c
o
r
r
u
p
t
x^{corrupt}
xcorrupt中每个位置是否被替换(是否跟原输入x一致)。
这样来看:生成器就是一个多标签分类问题,判别器就是一个多任务分类问题。
损失函数:
分别计算生成器的损失和判别器的损失。两项的联合损失作为损失函数。
生成器:生成器是一个多标签分类问题,其损失是对数似然函数。
L
M
L
M
(
x
,
θ
G
)
=
E
(
∑
i
∈
m
−
l
o
g
P
G
(
x
i
∣
x
m
a
s
k
e
d
)
)
L_{MLM}(x,\theta_G)=E(\sum_{i\in m}-log\ P_G(x_i|x^{masked}))
LMLM(x,θG)=E(i∈m∑−log PG(xi∣xmasked))
判别器:判别器是一个多任务二分类,其损失函数是BCEloss(二分类的交叉熵损失常用于多任务分类的loss)。
L
D
I
S
C
(
x
,
θ
D
)
=
E
(
∑
i
=
1
n
−
I
(
x
i
=
x
i
c
o
r
r
u
p
t
)
l
o
g
D
(
x
c
o
r
r
u
p
t
,
i
)
−
I
(
x
i
≠
x
i
c
o
r
r
u
p
t
)
(
1
−
l
o
g
D
(
x
c
o
r
r
u
p
t
,
i
)
)
)
L_{DISC}(x,\theta_D)=E(\sum_{i=1}^n-I(x_i=x_i^{corrupt})\mathrm{log}D(x^{corrupt},i)-\\I(x_i\neq x_i^{corrupt})(1-\mathrm{log}D(x^{corrupt},i)))
LDISC(x,θD)=E(i=1∑n−I(xi=xicorrupt)logD(xcorrupt,i)−I(xi=xicorrupt)(1−logD(xcorrupt,i)))
联合损失:
L
t
o
t
a
l
=
L
M
L
M
(
x
,
θ
G
)
+
λ
L
D
I
S
C
(
x
,
θ
D
)
L_{total}=L_{MLM}(x,\theta_G)+\lambda L_{DISC}(x,\theta_D)
Ltotal=LMLM(x,θG)+λLDISC(x,θD)
细节
- ELECTRA虽然也有生成器和判别器。但是不同于GAN。GAN的生成器使用的是对抗性训练来欺骗判别器。而ELECTRA使用的极大似然估计学习参数。通过实验这能改善下游任务的结果。
- 不能将判别器的损失传到生成器(因为使用了采样?)。
- 只在预训练使用生成器。在fine-turn的下游任务中,扔掉生成器,只使用判别器作为预训练模型。
模型的一些扩展:
- 参数共享
在预训练阶段,共享生成器和判别器的embedding矩阵。如果生成器和判别器的尺寸大小一致。则生成器和判别的的transformer权重可以共享。但是考虑到效率问题,一般会给生成器的更小尺寸。因为在ELECTRA中,生成器只是用来生成扰乱输入,并不用于下游的任务。所有给他更小的transformer结构,预训练阶段会更加高效。所以一般生成器较小,判别器较大。此时,让生成器和判别器的embedding矩阵共享。
实验验证:使用相同大小的生成器和判别器,训练500k steps。不绑定参数GLUE socre83.6。绑定embedding得分84.3,绑定所有transformers权重84.4。从实验结果假设ELECTRA可以从绑定词嵌入矩阵中受益,因为MLM任务在学习词表示时是非常有用的:因为判别器是一个多任务分类,他每次只更新 x c o r r u p t x^{corrupt} xcorrupt扰乱输入存在token的权重(BCELoss)。而生成器的softmax会更新字典中所有token的权重。
- 更小的生成器
让生成器尺寸更小,预训练的计算量就越小。只减少生成器的层数,保持其他超参不变。也对比了使用简单的“unigram”生成假样本,根据每个字在样例中出现的频率。实现发现生成器是判别器1/4-1/2效果最好。
- 训练算法
尝试了其他的训练算法,效果不好。
二阶段训练法:
(1) 先只训练 L M L M L_{MLM} LMLMn steps。
(2) 使用生成器的权重初始化判别器的权重后,冻结生成器权重,训练判别器权重n steps。
使用对抗的方式训练。
效果都不如ELECTRA的效果好。
其他理解:
- MLM相当于有选择的负采样。MLM会将简单易判的预测正确,只保留难学习的样本。
Bert对token的选择是随机的,所以可能会存在很多简单的token预测。例如“苹果”“苹[MASK]”。这些简单的token预测对模型的学习帮助很小。Electra就设计了用MLM专门挑选难学习的样本进行学习。
- 对Token自身信息的利用
在预训练任务中可以看到自身。Bert的Masked LM用周围的词去预测mask词。只看到了周围的单词。而ELECTRA同时用周围词加上自身去进行二分类任务。
- 判别器学习的文本表示适用于下游任务么?
BERT的多分类任务放到下游任务是合理的。BERT的V分类任务更具一般性。为token生成了一个基于上下文的表示。但ELECTRA判别器的二分类任务会导致隐含空间过早退化。遇到复杂任务不能得到丰富的表示。
- 与GAN的区别
GAN | ELECTRA | |
---|---|---|
训练策略 | 对抗训练 | 联合loss |
负样本 | 生成器生成的都是负样本 | 生成器生成预测正确作为正样本,只有预测错误的是负样本 |
输入 | 随机噪声 | 真实文本 |
梯度 | 梯度可以从D传到G | 梯度不能从D传到G |