《Knowledge Distillation for BERT Unsupervised Domain Adaptation》阅读笔记
原文链接:Paper
摘要
预训练语言模型BERT在一系列自然语言处理任务中带来了显著的性能改进。由于该模型是在不同主题的大型语料库上训练的,因此它在训练(源数据)和测试(目标数据)的数据分布不同但共享相似性的领域转移问题上表现出强大的性能。尽管与以前的模型相比有了很大的改进,但由于域偏移的存在,它仍然存在性能下降的问题。为解决该问题,本文提出一种简单而有效的无监督领域自适应方法——对抗自适应与蒸馏(AAD),将对抗判别领域自适应(ADDA)框架与知识蒸馏相结合。在30个领域对的跨领域情感分类任务中评估了此方法,提高了文本情感分类中无监督领域自适应的最先进性能。
一、简单介绍
在本文中,作者提出了一种新的针对预训练语言模型的对抗域自适应方法,称为蒸馏对抗自适应(AAD)。这项工作是在Tzeng等人提出的名为对抗性歧视域适应(ADDA)的框架之上完成的。作者观察到,当ADDA框架应用于BERT模型时,会发生灾难性的遗忘,而不是在应用于深度卷积神经网络时。在ADDA中,微调的源模型被用作初始化,以防止目标模型因为没有标签信息而学习退化解。不幸的是,这种方法本身并不能防止BERT中的灾难性遗忘,从而导致随机分类性能。为了克服这个问题,作者采用了知识蒸馏方法,该方法主要是通过从大模型转移知识来提高小模型的性能。作者发现,这种方法可以作为一种正则化方法,在保持源数据学习到的信息的同时,使得到的模型具有域自适应能力,避免过拟合。
二、相关工作
2.1 无监督域自适应
无监督域自适应,顾名思义,目标域的样本是没有标注的。先前的其它方法此处省略不述。
ADDA 被提出作为一个对抗框架,包括判别建模、非联合权重共享和基于GAN的损失。首先用标记的源数据训练源编码器,并将权重复制到目标编码器。然后,像原始GAN设置一样,在双人游戏中交替优化目标编码器和鉴别器。鉴别器学习区分目标表示和源表示,而编码器学习欺骗鉴别器。Chadha等人通过修改鉴别器来改进ADDA框架,以共同预测源标签并将目标域的输入区分为半监督GAN。作者的研究与Chadha等人的工作相似,因为它也在对抗性适应步骤中使用了源信息。然而,不同之处在于,作者的方法使用的是知识蒸馏,是在网络中使用源信息的手段,而不是直接使用源标签。
2.2 知识蒸馏
知识蒸馏(KD)最初是一种模型压缩技术,旨在训练一个紧凑的模型(学生),以便将训练有素的较大模型(教师)的知识转移到学生模型中。KD可以通过最小化以下目标函数来表示:
L
K
D
=
t
2
×
∑
k
−
s
o
f
t
m
a
x
(
z
k
T
/
t
)
×
l
o
g
(
s
o
f
t
m
a
x
(
z
k
S
/
t
)
)
(1)
L_{KD}=t^2\times\sum_k-softmax(z_k^T/t)\times log(softmax(z_k^S/t)) \tag{1}
LKD=t2×k∑−softmax(zkT/t)×log(softmax(zkS/t))(1)
其中
z
S
z^S
zS和
z
T
z^T
zT分别是学生模型和教师模型预测的对数,温度系数
t
t
t控制知识转移的程度。
在监督学习中,标准的训练目标是最小化模型预测概率分布与单热编码标签真实概率分布之间的交叉熵损失。然而,这个目标很容易导致重复训练时期的过拟合。由于较大的 t t t值产生较软的概率分布,知识蒸馏与领域自适应方法结合可以缓解这个问题。
2.3 来自Transformers的双向编码器表示
BERT是一种自监督的方法,用于预训练深度transformer编码器。BERT模型是在一个大型语料库上使用屏蔽语言建模和下一句预测进行训练的。在实验中,作者使用BERT、distilBERT和RoBERTa来评估他们的方法。
三、带有蒸馏的对抗域适应
源域数据:
X
S
=
{
(
x
s
i
)
}
i
=
0
N
s
,
y
S
=
{
(
y
s
i
)
}
i
=
0
N
s
X_S=\{(x_s^i)\}_{i=0}^{N_s},y_S=\{(y_s^i)\}_{i=0}^{N_s}
XS={(xsi)}i=0Ns,yS={(ysi)}i=0Ns with
(
x
s
,
y
s
)
∼
(
X
S
,
Y
S
)
(x_s,y_s)\sim(\mathbb{X}_S,\mathbb{Y}_S)
(xs,ys)∼(XS,YS)
目标域数据:
X
T
=
{
(
x
t
i
)
}
i
=
0
N
t
X_T=\{(x_t^i)\}_{i=0}^{N_t}
XT={(xti)}i=0Nt with
x
t
∼
X
T
x_t\sim\mathbb{X}_T
xt∼XT
源编码器:
E
s
(
x
)
E_s(x)
Es(x),其中
x
x
x是网络的输入
目标编码器:
E
t
(
x
)
E_t(x)
Et(x)
将源编码器输出映射到类概率的分类器函数:
C
C
C
将编码器输出(源或目标)映射到域概率的判别器函数:
D
D
D
假设目标数据与源数据共享相同的标签空间。在无监督域自适应中,目标是在不访问目标标签的情况下,通过学习最小化源数据表示与目标数据表示之间的距离来获得更好的目标数据性能。
本文提出的方法包括以下三个步骤:
- 在源数据上训练源编码器和分类器
- 通过对抗性训练和蒸馏使目标编码器的表示与源表示对齐
- 使用自适应的目标编码器和训练好的分类器对目标数据进行推断
3.1 步骤1:对源编码器和分类器进行微调
获得标记的源数据后,首先使用标准交叉熵损失对
X
S
X_S
XS和
y
S
y_S
yS上的源编码器
E
s
E_s
Es和分类器
C
C
C进行微调:
m
i
n
E
s
,
C
L
S
(
X
S
,
y
S
)
=
E
(
x
s
,
y
s
)
∼
(
X
S
,
Y
S
)
−
∑
k
=
1
K
1
[
k
=
y
s
]
l
o
g
C
(
E
s
(
x
s
)
)
(2)
\underset{E_s,C}{min}\mathcal{L}_S(X_S,y_S)=\mathbb{E}_{(x_s,y_s)\sim(\mathbb{X}_S,\mathbb{Y}_S)}-\sum_{k=1}^K\mathbb{1} _{[k=y_s]}logC(E_s(x_s)) \tag{2}
Es,CminLS(XS,yS)=E(xs,ys)∼(XS,YS)−k=1∑K1[k=ys]logC(Es(xs))(2)
其中
K
K
K是类的个数。在用微调后的源编码器参数初始化目标编码器参数后,冻结源编码器参数和分类器。
3.2 步骤2:通过蒸馏对抗性适应来适应目标编码器
在原始GAN设置中交替训练目标编码器和鉴别器。这步可以用图1里的step 2-(a)中的无约束优化来表示:
m
i
n
D
L
d
i
s
(
X
S
,
X
T
)
=
E
x
s
∼
X
S
−
l
o
g
D
(
E
s
(
x
s
)
)
+
E
x
t
∼
X
T
−
l
o
g
(
1
−
D
(
E
t
(
x
t
)
)
)
,
m
i
n
E
t
L
g
e
n
(
X
T
)
=
E
x
t
∼
X
T
−
l
o
g
D
(
E
t
(
x
t
)
)
(3)
\begin{aligned} &\underset{D}{min}\mathcal{L}_{dis}(X_S,X_T)=\mathbb{E}_{x_s\sim \mathbb{X}_S}-logD(E_s(x_s))+\mathbb{E}_{x_t\sim \mathbb{X}_T}-log(1-D(E_t(x_t))), \\&\underset{E_t}{min}\mathcal{L}_{gen}(X_T)=\mathbb{E}_{x_t\sim \mathbb{X}_T}-logD(E_t(x_t)) \tag{3} \end{aligned}
DminLdis(XS,XT)=Exs∼XS−logD(Es(xs))+Ext∼XT−log(1−D(Et(xt))),EtminLgen(XT)=Ext∼XT−logD(Et(xt))(3)
由于它具有来自源编码器的未绑定权重,因此目标编码器可以更灵活地学习特定的域特征。然而,由于无法获得类别标签和与原始任务的不同而导致的随机分类,该公式容易导致灾难性遗忘。
为了增强对抗性训练的稳定性,可以考虑直接使用源标签作为监督学习方法。然而,尽管此方法可能预防对抗性适应中的模式崩溃,但这可能导致模型过度拟合源域数据。而另一方法,知识蒸馏,可以为模型提供对抗适应的灵活性和保留大的温度值
t
t
t的类信息的能力。因此,为模型引入如图1 step 2-(b)中的知识蒸馏损失:
L
K
D
(
X
S
)
=
t
2
+
E
x
s
∼
X
S
∑
k
=
1
K
−
s
o
f
t
m
a
x
(
z
k
S
/
t
)
×
l
o
g
(
s
o
f
t
m
a
x
(
z
k
T
/
t
)
)
(4)
\mathcal{L}_{KD}(X_S)=t^2+\mathbb{E}_{x_s\sim \mathbb{X}_S}\sum_{k=1}^K-softmax(z_k^S/t)\times log(softmax(z_k^T/t)) \tag{4}
LKD(XS)=t2+Exs∼XSk=1∑K−softmax(zkS/t)×log(softmax(zkT/t))(4)
其中
z
S
=
C
(
E
s
(
x
s
)
)
,
z
T
=
C
(
E
t
(
x
s
)
)
z^S=C(E_s(x_s)),z^T=C(E_t(x_s))
zS=C(Es(xs)),zT=C(Et(xs))。
因此,训练目标编码器的最终目标函数为:
m
i
n
E
t
L
T
(
X
S
,
X
T
)
=
L
g
e
n
(
X
T
)
+
L
K
D
(
X
S
)
(5)
\underset{E_t}{min}\mathcal{L}_T(X_S,X_T)=\mathcal{L}_{gen}(X_T)+\mathcal{L}_{KD}(X_S)\tag{5}
EtminLT(XS,XT)=Lgen(XT)+LKD(XS)(5)
最后,将式(3)中的第二个目标函数替换为式(5),然后通过交替最小化目标函数训练鉴别器和目标编码器。
3.3 步骤3:在目标数据上测试目标编码器
如图1的step 3所示,使用经过微调的分类器进行推理,得到的预测结果如下:
y
t
^
=
a
r
g
m
a
x
C
(
E
t
(
x
t
)
)
(6)
\hat{y_t}=argmax \ C(E_t(x_t))\tag{6}
yt^=argmax C(Et(xt))(6)
实验
代码资源:Code
个人见解
本论文思路和清晰,简单来说,就是针对对抗域适应的生成器(本文中称为目标编码器)进行优化,利用知识蒸馏的手段优化了灾难性遗忘的问题。
个人认为论文中有一处小错误,即公式(4)中, z k S z_k^S zkS和 z k T z_k^T zkT的位置反了,它与公式(1)是相矛盾的。
以上仅个人见解,如有不得当之处还望指出。