对抗判别式领域自适应
论文链接:https://ieeexplore.ieee.org/document/8099799/
文献:E. Tzeng, J. Hoffman, K. Saenko, and T. Darrell, “Adversarial discriminative domain adaptation,” in Proceedings of the 30th IEEE Conference on Computer Vision and Pattern Recognition, CVPR 2017, 2017, vol. 2017-July, pp. 2962–2971, doi: 10.1109/CVPR.2017.316.
领域自适应技术在不匹配说话人识别的问题中非常有效。这篇文章是图像领域的判别式对抗自适应方法,也同样可以迁移至说话人识别领域。
文章目录
摘要
目的:在无监督领域自适应问题中,对抗方法是减少训练分布和测试分布之间的差别、改善泛化性能的有效又短。但如今的生成式方法在判别判别任务上性能不佳,而判别式方法,尽管能处理较大的域变换,但还未充分利用对抗生成网络的损失函数。
数据与方法:作者提出了对抗自适应方法的广义框架,进而在此基础上,提出了对抗判别式领域自适应方法 (ADDA),该方法涉及判别式模型、无共享权重和 GAN 损失。提出的方法在三个任务上进行测试:无监督领域自适应基准任务 - 数字(MNIST、USPS 和 SVHN)、跨模态的自适应学习任务(NYUD)和跨视觉域的自适应学习任务(标准 Office - amazon, webcam, dslr)。
结果:在基准任务上,当源域和目标域相近时(
MNIST
⇋
USPS
\text{MNIST}\leftrightharpoons \text{USPS}
MNIST⇋USPS),ADDA 与生成式方法相当;在跨模态的自适应学习任务上,大部分类别的分类性能得到了显著改进,但也存在性能降低的类别;在跨视觉域的自适应学习任务上,不同模型的ADDA都获得了一致的性能提升。
1. 引言
深度学习在各种任务和视觉领域上能够学到各种表示,然而,领域变化/领域偏差导致这些表示在新的数据集和任务上的泛化效果不佳。针对这一问题,典型的解决方法是在针对任务的数据集上进行精调。但是,获取用于精调深度网络的大规模数据是非常困难。
领域自适应方法可以减轻领域变化的有害影响,其思想是学习到两个领域的共同特征空间。其手段是以实现最小化领域变化的距离为目标优化表示,其中领域变化的衡量方法是:
- maximum mean discrepancy (MMD):计算两个域均值之差的范数;
- correlation distances (CORAL):匹配两个分布的均值和协方差;
- 对抗损失:例如 GAN 方法的损失、反向梯度、领域混淆的损失。
对抗自适应方法是通过关于领域判别器的对抗优化目标来最小化的领域散度距离,常见的案例如生成对抗网络让生成器产生让判别器误导的图片。不同的对抗自适应方法的设计需要考虑三点:
- 是否使用生成器,
- 采用何种损失函数,
- 是否共享跨域权重。
2. 广义对抗自适应
对抗无监督自适应方法的一般框架:
-
源域: X s X_s Xs 及其标签 Y s Y_s Ys 由 p s ( x , y ) p_s(x,y) ps(x,y) 获得,
-
目标域: X t X_t Xt 由 p t ( x , y ) p_t(x,y) pt(x,y) 获得,但是无标签,
-
目标:学习目标域表示 M t M_t Mt 和目标域分类器 C t C_t Ct,使得能够在测试阶段正确地分类目标样本,即便在缺少域注释的情况下;
-
领域自适应方法:学习源域的表示映射 M s M_s Ms 和源域的分类器 C s C_s Cs,然后学会适应在目标域上的模型使用。
-
对抗自适应方法:主要目标是正则化源域和目标域映射( M s M_s Ms 和 M t M_t Mt)的学习过程,以实现最小化经验性的源域和目标域映射( M s ( X s ) M_s(X_s) Ms(Xs) 和 M t ( X t ) M_t(X_t) Mt(Xt))分布之间的距离。最后使得源域分类器 C s C_s Cs 能够师姐用于目标域表示,而不需要再单独为目标域训练分类器,即 C = C s = C t C = C_s = C_t C=Cs=Ct。
其中分类器的监督损失:
min
M
s
,
C
L
cls
(
X
s
,
Y
s
)
=
−
E
(
x
s
,
y
s
)
∼
(
X
s
,
Y
s
)
∑
k
=
1
K
1
[
k
=
y
s
]
log
C
(
M
s
(
x
s
)
)
\begin{aligned} &\min\limits_{M_s,C}\,\mathcal{L}_{\text{cls}}(\mathbf{X}_s,Y_s)=\\ &\quad\quad-\mathbb{E}_{(\mathbf{x}_s,y_s)\sim(\mathbf{X}_s,Y_s)}\sum\limits_{k=1}^K\mathbb{1}_{[k=y_s]}\log{C(M_s(\textbf{x}_s))} \end{aligned}
Ms,CminLcls(Xs,Ys)=−E(xs,ys)∼(Xs,Ys)k=1∑K1[k=ys]logC(Ms(xs))
判别器的监督损失:
L
adv
D
(
X
s
,
X
t
,
M
s
,
M
t
)
=
−
E
x
s
∼
X
s
[
log
D
(
M
s
(
x
s
)
)
]
−
E
x
t
∼
X
t
[
log
(
1
−
D
(
M
t
(
x
t
)
)
)
]
\begin{aligned} &\mathcal{L}_{\text{adv}_D}(\mathbf{X}_s,\mathbf{X}_t,M_s,M_t)=\\ &\quad\quad-\mathbb{E}_{\mathbf{x}_s\sim\mathbf{X}_s}\left[\log{D(M_s(\textbf{x}_s))}\right]\\ &\quad\quad\quad\quad-\mathbb{E}_{\mathbf{x}_t\sim\mathbf{X}_t}\left[\log{(1-D(M_t(\textbf{x}_t)))}\right] \end{aligned}
LadvD(Xs,Xt,Ms,Mt)=−Exs∼Xs[logD(Ms(xs))]−Ext∼Xt[log(1−D(Mt(xt)))]
作者给出了领域对抗技术的一般公式,依次分别是 1)判别器监督损失,2)对抗映射损失,3)(领域)映射优化约束。其中第 1 点的判别器监督损失以在上述模型列出。
min
D
L
adv
D
(
X
s
,
X
t
,
M
s
,
M
t
)
min
M
s
,
M
t
L
adv
M
(
X
s
,
X
t
,
D
)
s
.
t
.
ψ
(
M
s
,
M
t
)
\begin{aligned} \min\limits_D\,\,&\mathcal{L}_{\text{adv}_D}(\mathbf{X}_s,\mathbf{X}_t,M_s,M_t)\\ \min\limits_{M_s,M_t}\,\,&\mathcal{L}_{\text{adv}_M}(\mathbf{X}_s,\mathbf{X}_t,D)\\ s.t.\,\,&\psi(M_s,M_t) \end{aligned}
DminMs,Mtmins.t.LadvD(Xs,Xt,Ms,Mt)LadvM(Xs,Xt,D)ψ(Ms,Mt)
2.1 源映射与目标映射
ψ ( M s , M t ) \psi(M_s,M_t) ψ(Ms,Mt) 设计思路:
-
最小化两领域各自映射之间的距离,
-
维持目标映射的类别可分性。
神经网络的映射可以采用层结构的符号进行表示,即
ψ
(
M
s
,
M
t
)
≜
{
ψ
i
(
M
s
i
,
M
t
i
)
}
i
∈
{
1
…
n
}
\psi(M_s,M_t)\triangleq\{\psi_i(M_s^i,M_t^i)\}_{i\in\{1\dots n\}}
ψ(Ms,Mt)≜{ψi(Msi,Mti)}i∈{1…n}
神经网络模型的表示约束形式主要有两种:
-
源域和目标域层表示相同,可通过 CNN 权重共享的方式来实现,即
-
ψ i ( M s i , M t i ) = ( M s i = M t i ) \psi_i(M_s^i,M_t^i)=(M_s^i=M_t^i) ψi(Msi,Mti)=(Msi=Mti)
-
所有层的表示都约束(即对称变换)会造成相同的网络处理两个独立的领域,造成优化的条件非常差。
-
-
源域和目标域层表示无约束。
- 层子集表示约束(即约束部分层,非对称变换)允许模型能够学习每个领域的独立参数。
2.2 对抗损失
作者采用了 GAN 损失的形式,即
L
adv
M
(
X
s
,
X
t
,
D
)
=
−
E
x
t
∼
X
t
[
log
D
(
M
t
(
x
t
)
)
]
\mathcal{L}_{\text{adv}_M}(\mathbf{X}_s,\mathbf{X}_t,D)=-\mathbb{E}_{\mathbf{x}_t\sim\mathbf{X}_t}\left[\log{D(M_t(\textbf{x}_t))}\right]
LadvM(Xs,Xt,D)=−Ext∼Xt[logD(Mt(xt))]
3. 对抗判别式领域自适应
在领域对抗方法的广义框架下,有三个设计原则:
-
使用生成式模型还是判别式模型?
-
是否系紧或解开模型权重?
-
使用哪种对抗学习的优化目标?
对三个原则确定各自的方法,就可以获得一种新的领域自适应方法。而其中的判别式方法就是在第一原则中选择判别式模型的方法。
根据上述的三种设计原则,作者提出的对抗判别式领域自适应的三个设计原则分别是:
- 判别式模型,其依据是:
- 生成式模型的参数和判别式自适应任务大多是无关的,
- 生成式模型大多适用于源域和目标域非常相近的场合。
- 权重不共享,在优化过程中会出现的问题和解决的方法:
- 问题:无权重共享、无标签的目标域模型可能会学到一个退化的解?
- 解决方法:预训练的源域模型作为目标表示空间的实例。
- 标准的 GAN 损失。
其整体的损失函数为:
min
M
s
,
C
L
cls
(
X
s
,
Y
s
)
=
−
E
(
x
s
,
y
s
)
∼
(
X
s
,
Y
s
)
∑
k
=
1
K
1
[
k
=
y
s
]
log
C
(
M
s
(
x
s
)
)
min
D
L
adv
D
(
X
s
,
X
t
,
M
s
,
M
t
)
−
E
x
s
∼
X
s
[
log
D
(
M
s
(
x
s
)
)
]
−
E
x
t
∼
X
t
[
log
(
1
−
D
(
M
t
(
x
t
)
)
)
]
min
M
t
L
adv
M
(
X
s
,
X
t
,
D
)
=
−
E
x
t
∼
X
t
[
log
D
(
M
t
(
x
t
)
)
]
\begin{aligned} &\min\limits_{M_s,C}\,\mathcal{L}_{\text{cls}}(\mathbf{X}_s,Y_s)=\\ &\quad\quad-\mathbb{E}_{(\mathbf{x}_s,y_s)\sim(\mathbf{X}_s,Y_s)}\sum\limits_{k=1}^K\mathbb{1}_{[k=y_s]}\log{C(M_s(\textbf{x}_s))}\\ &\min\limits_D\,\,\mathcal{L}_{\text{adv}_D}(\mathbf{X}_s,\mathbf{X}_t,M_s,M_t)\\ &\quad\quad-\mathbb{E}_{\mathbf{x}_s\sim\mathbf{X}_s}\left[\log{D(M_s(\textbf{x}_s))}\right]\\ &\quad\quad\quad\quad-\mathbb{E}_{\mathbf{x}_t\sim\mathbf{X}_t}\left[\log{(1-D(M_t(\textbf{x}_t)))}\right]\\ &\min\limits_{M_t}\,\,\mathcal{L}_{\text{adv}_M}(\mathbf{X}_s,\mathbf{X}_t,D)=\\ &\quad\quad-\mathbb{E}_{\mathbf{x}_t\sim\mathbf{X}_t}\left[\log{D(M_t(\textbf{x}_t))}\right] \end{aligned}
Ms,CminLcls(Xs,Ys)=−E(xs,ys)∼(Xs,Ys)k=1∑K1[k=ys]logC(Ms(xs))DminLadvD(Xs,Xt,Ms,Mt)−Exs∼Xs[logD(Ms(xs))]−Ext∼Xt[log(1−D(Mt(xt)))]MtminLadvM(Xs,Xt,D)=−Ext∼Xt[logD(Mt(xt))]
这个方法的训练过程是一个分阶段的过程:
-
优化分类器和对抗(源域)映射 C C C 和 M s M_s Ms,
-
优化判别器和对抗(目标域)映射 D D D 和 M t M_t Mt。
4. 无监督领域自适应实验
实验分为三个任务:
- 无监督领域自适应基准任务:数字数据集,MNIST、USPS 和 SVHN。
- 跨视觉域的自适应学习任务:标准 Office 数据集。
- 跨模态的自适应学习任务:NYUD 数据集。
数据集的示例如下图所示:
实验涉及 4 种算法分别是 Gradient reversal、Domain confusion、CoGAN 和作者提出的 ADDA。
方法 | 基础模型 | 权重共享 | 对抗损失 |
---|---|---|---|
Gradient reversal | discriminative | shared | minimax |
Domain confusion | discriminative | shared | confusion |
CoGAN | generative | unshared | GAN |
ADDA | discriminative | unshared | GAN |
4.1 无监督领域自适应基准任务
-
数据:数字数据集,MNIST, USPS, SVHN,10个类别:0-9。
- MNIST:2000
- USPS:1800
- SVHN:所有的训练集
-
迁移任务:MNIST → \to → USPS, USPS → \to → MNIST, SVHN → \to → MNIST
-
方法及其设定:
- 基础模型 LeNet,
- ADDA 判别器:500 FC + ReLU + 500 FC + ReLU + 输出,
- 优化参数:Adam、10000 次迭代、学习率 0.0002、 β 1 = 0.5 \beta_1=0.5 β1=0.5、 β 2 = 0.999 \beta_2=0.999 β2=0.999、每个 batch 256 张图片、128 张图片/领域、灰度图、28$\times$28 像素。
- 对比方法:Gradient reversal、Domain confusion、CoGAN。
-
实验结果( ⋆ ^\star ⋆ 表示最佳性能):生成式方法 CoGAN 适用于相似的数据集,在相似度低的数据集上,未收敛;判别式方法 ADDA 获得了显著的性能效果。
方法 MNIST → \to → USPS USPS → \to → MNIST SVHN → \to → MNIST 尽源域数据 0.752 ± 0.016 0.752 \pm 0.016 0.752±0.016 0.571 ± 0.017 0.571 \pm 0.017 0.571±0.017 0.601 ± 0.011 0.601 \pm 0.011 0.601±0.011 Gradient reversal 0.771 ± 0.018 0.771 \pm 0.018 0.771±0.018 0.730 ± 0.020 0.730 \pm 0.020 0.730±0.020 0.739 0.739 0.739 Domain confusion 0.791 ± 0.005 0.791 \pm 0.005 0.791±0.005 0.665 ± 0.033 0.665 \pm 0.033 0.665±0.033 0.681 ± 0.003 0.681 \pm 0.003 0.681±0.003 CoGAN 0.912 ± 0.00 8 ⋆ 0.912 \pm 0.008^\star 0.912±0.008⋆ 0.891 ± 0.008 0.891 \pm 0.008 0.891±0.008 未收敛 ADDA 0.894 ± 0.002 0.894 \pm 0.002 0.894±0.002 0.901 ± 0.00 8 ⋆ 0.901 \pm 0.008^\star 0.901±0.008⋆ 0.760 ± 0.01 8 ⋆ 0.760 \pm 0.018^\star 0.760±0.018⋆
4.2 跨视觉域的自适应学习任务
-
数据:NYU depth 数据集,19 个目标类别,共 1449 张室内图片,分别划分为三个子集,然后在对每一张图片的目标裁剪出实例:
- 训练集(源域 RGB):381 张图片,裁剪出 2186 个图片
- 验证集(目标域 HHA):414 张图片,裁剪出 2401 个图片
- 测试集:654 张图片,
-
迁移任务:RGB → \to → HHA
-
方法及其设定:
- 基础模型 VGG-16:在 ImageNet 上进行预训练
- 源域上的精调:20000 次迭代,批量大小 128
- ADDA 判别器:1024 FC + ReLU + 2048 FC + ReLU + 输出
- 优化参数:20000 迭代,其它设置与数字数据集训练设置一样。
- 基础模型 VGG-16:在 ImageNet 上进行预训练
-
实验结果( ⋆ ^\star ⋆ 表示性能显著提升, △ ^\triangle △ 表示性能下降):验证集总样本数 2401,源域数据训练性能 0.139 0.139 0.139,ADDA 学习的性能 0.211 0.211 0.211,目标域数据训练性能 0.468 0.468 0.468。
bathtub bed bookshelf box chair 样本数 19 96 87 210 611 仅源域 0.000 0.010 0.011 0.124 0.188 ADDA 0.000 0.146 0.046 0.229 ⋆ ^\star ⋆ 0.344 ⋆ ^\star ⋆ 仅目标域 0.105 0.531 0.494 0.295 0.619 counter desk door dresser garbage bin 样本数 103 122 129 25 55 仅源域 0.029 0.041 0.047 0.000 0.000 ADDA 0.447 ⋆ ^\star ⋆ 0.025 △ ^\triangle △ 0.023 △ ^\triangle △ 0.000 0.018 仅目标域 0.573 0.057 0.636 0.120 0.291 lamp monitor night stand pillow sink 样本数 144 37 51 276 47 仅源域 0.069 0.000 0.039 0.587 0.000 ADDA 0.292 ⋆ ^\star ⋆ 0.081 0.020 △ ^\triangle △ 0.297 △ ^\triangle △ 0.021 仅目标域 0.576 0.189 0.235 0.630 0.362 sofa table television toilet 整体性能 样本数 129 210 33 17 2401 仅源域 0.008 0.010 0.000 0.000 0.139 ADDA 0.116 ⋆ ^\star ⋆ 0.143 ⋆ ^\star ⋆ 0.091 0.000 0.211 ⋆ ^\star ⋆ 仅目标域 0.248 0.357 0.303 0.647 0.468
4.3 跨模态的自适应学习任务
-
数据:Office 数据集,分别来自三个领域 amazon (A), webcam (W), dslr (D),共 4110 张图片。
-
迁移任务:A → \to → W, D → \to → W,
-
方法及其设定:
- 基础模型:AlexNet(仅对比) 或 ResNet-50(精调过程不包含 conv5 及以上结构)
- ADDA 判别器:1024 FC + ReLU + 2048 FC + ReLU + 3072 FC + ReLU + 输出
- 优化过程:SGD,2000 次迭代,学习速率 0.001,动量 0.9,批量大小 64
-
实验结果( ⋆ ^\star ⋆ 表示性能最优):ADDA 在所有迁移任务上获得了一致的提升效果。
方法 A → \to → W D → \to → W W → \to → D AlexNet 仅源域 0.642 0.961 0.978 ResNet-50 仅源域 0.626 0.961 0.986 DDC 0.618 0.950 0.985 DAN 0.685 0.960 0.990 DRCN 0.687 0.964 0.990 DANN 0.730 0.964 0.992 ADDA 0.751 ⋆ ^\star ⋆ 0.970 ⋆ ^\star ⋆ 0.996 ⋆ ^\star ⋆
参考文献
[1] Eric Tzeng, Judy Hoffman, Ning Zhang, Kate Saenko, and Trevor Darrell. Deep domain confusion: Maxi- mizing for domain invariance. CoRR, abs/1412.3474, 2014. 2, 8
[2] Mingsheng Long and Jianmin Wang. Learning transfer- able features with deep adaptation networks. Interna- tional Conference on Machine Learning (ICML), 2015. 2, 8
[3] Muhammad Ghifary, W Bastiaan Kleijn, Mengjie Zhang, David Balduzzi, and Wen Li. Deep reconstruction-classification networks for unsupervised domain adaptation. In European Conference on Com- puter Vision (ECCV), pages 597–613. Springer, 2016. 2, 8
[4] Yaroslav Ganin, Evgeniya Ustinova, Hana Ajakan, Pas- cal Germain, Hugo Larochelle, Franc¸ois Laviolette, Mario Marchand, and Victor Lempitsky. Domain- adversarial training of neural networks. Journal of Machine Learning Research, 17(59):1–35, 2016. 4, 5, 6, 7, 8
作者:王瑞 同济大学 计算机系博士研究生
邮箱:rwang@tongji.edu.cn
CSDN:https://blog.csdn.net/i_love_home
Github:https://github.com/mechanicalsea
欢迎大家的提问~