(AAAI2020)Adversarial Domain Adaptation with Domain Mixup论文笔记

(AAAI2020)Adversarial Domain Adaptation with Domain Mixup论文笔记

文章链接
代码链接

基于对抗方式的对齐源域数据和目标域数据。本文提出一种mixup的对齐方式。传统的对齐方式中,直接在对齐源域和目标域的特征分布。而该方法中,构造出一些minup的数据,相当于在源域的特征和目标域的特征之间建立桥梁,逐渐对齐特征。

在这里插入图片描述

图中表达的意思是,该方法将构造出一些mixup的图片,通过这些图片建立源域和目标域之间的联系,从而对齐源域的特征和目标域的特征。

模型结构

在这里插入图片描述

对于输入数据 x s x^s xs x t x^t xt,我们将其mixup,合并的权值使用 λ \lambda λ,即

x m = λ x s + ( 1 − λ ) x t x^m=\lambda x^s + (1-\lambda)x^t xm=λxs+(1λ)xt

将源域和目标域数据输入编码器 N e N_e Ne,得到两个特征向量 μ \mu μ σ \sigma σ。源域的命名为 μ s \mu_s μs σ s \sigma_s σs,目标域的命名为 μ s \mu_s μs σ t \sigma_t σt

命名时候感觉其中一个代表均值,另一个代表方差。但在代码中实现的时候, N e N_e Ne是卷积层后接全连接层, μ \mu μ σ \sigma σ公用卷积层,卷积层后接两个全连接层分别输出 μ \mu μ σ \sigma σ

之后我们估计mixup图片的 μ \mu μ σ \sigma σ,估计方法同样使用 λ \lambda λ加权

μ m = λ μ s + ( 1 − λ ) μ t \mu_m=\lambda\mu_s+(1-\lambda)\mu_t μm=λμs+(1λ)μt

σ m = λ σ s + ( 1 − λ ) σ t \sigma_m=\lambda\sigma_s+(1-\lambda)\sigma_t σm=λσs+(1λ)σt

之后拼接起来作为解码器 N d N_d Nd的输入,解码器的输入包括 [ μ , σ , z , l c l s , l c o m p ] [\mu,\sigma,z,l_{cls},l_{comp}] [μ,σ,z,lcls,lcomp]

z z z是一个噪声向量, l c l s l_{cls} lcls代表类别向量, l c o m p l_{comp} lcomp代表域类别向量。源域,目标域和mixup的标签信息为

源域: l c l s s = [ 0 , 0... , 1 , . . . , 0 ] , l c o m p s = 0 l_{cls}^s=[0,0...,1,...,0],l_{comp}^s=0 lclss=[0,0...,1,...,0]lcomps=0

目标域: l c l s t = [ 0 , 0... , 0 , . . . , 0 ] , l c o m p s = 1 l_{cls}^t=[0,0...,0,...,0],l_{comp}^s=1 lclst=[0,0...,0,...,0]lcomps=1

mixup图片: l c l s m = [ 0 , 0... , λ , . . . , 0 ] , l c o m p s = 1 − λ l_{cls}^m=[0,0...,\lambda,...,0],l_{comp}^s=1-\lambda lclsm=[0,0...,λ,...,0]lcomps=1λ

N d N_d Nd可以认为成一个条件生成网络,目的是产生出图片。

之后通过计算损失函数来优化参数。

损失函数

首先对于 μ \mu μ σ \sigma σ,我们希望数据的分布和 N ( 0 , 1 ) N(0,1) N(0,1)对齐,所以使用KL散度对齐,损失函数为

m i n N e L K L = D K L ( N ( μ , σ ) ∣ ∣ N ( 0 , I ) ) min_{N_e}L_{KL}=D_{KL}(N(\mu,\sigma)||N(0,I)) minNeLKL=DKL(N(μ,σ)N(0,I))

这步损失函数表明,虽然都是卷积层+全连接层的输出。但 μ \mu μ σ \sigma σ在意义上是存在区别的。

之后和传统的对抗方法类似,我们让判别器D和编码器解码器进行对抗,损失函数为

min ⁡ N e , N d max ⁡ D L a d v s + L a d v t + L a d v m \min\limits_{N_e,N_d}\max\limits_{D}L^s_{adv}+L^t_{adv}+L^m_{adv} Ne,NdminDmaxLadvs+Ladvt+Ladvm

其中

L a d v s = E x s ∼ P s l o g ( D d o m ( x s ) ) + l o g ( 1 − D d o m ( x g s ) ) L^s_{adv}=E_{x^s\sim P_s}log(D_{dom}(x^s))+log(1-D_{dom}(x^s_g)) Ladvs=ExsPslog(Ddom(xs))+log(1Ddom(xgs))

L a d v t = E x t ∼ P t l o g ( 1 − D d o m ( x g t ) ) L^t_{adv}=E_{x^t\sim P_t}log(1-D_{dom}(x_g^t)) Ladvt=ExtPtlog(1Ddom(xgt))

L a d v m = E x s ∼ P s , x t ∼ P t l o g ( 1 − D d o m ( x g m ) ) L_{adv}^m=E_{x^s\sim P_s,x^t\sim P_t}log(1-D_{dom}(x_g^m)) Ladvm=ExsPs,xtPtlog(1Ddom(xgm))

式子中的 x s , x t x^s,x^t xs,xt表示源域和目标域的原始图片, x g s , x g t , x g m x^s_g,x^t_g,x^m_g xgs,xgt,xgm分别表示源域的生成图片,目标域的生成图片和mixup的生成图片。

  1. 损失函数中的 L a d v s L^s_{adv} Ladvs D D D的目标是区分出 x s x_s xs x g s x_g^s xgs的区别,而 N e , N d N_e,N_d Ne,Nd的目标是混淆 x s x_s xs x g s x_g^s xgs。我们可以将 x g s x_g^s xgs看成另一个领域的数据,起初, x s x_s xs x g s x_g^s xgs是两个不同的域的数据,我们希望可以将 x s x_s xs x g s x_g^s xgs对齐,与 x s x_s xs和$x_t对齐类似。所以这个损失函数和传统的对抗方法的域判别损失类似。
  2. 损失函数中的 L a d v t L^t_{adv} Ladvt是类似于传统对抗方法中的域判别损失的一部分。将其中的 E x t ∼ P t l o g ( 1 − D d o m ( x t ) ) E_{x^t\sim P_t}log(1-D_{dom}(x^t)) ExtPtlog(1Ddom(xt))换成了 E x t ∼ P t l o g ( 1 − D d o m ( x g t ) ) E_{x^t\sim P_t}log(1-D_{dom}(x_g^t)) ExtPtlog(1Ddom(xgt))。将原始图片更换成了生成图片
  3. 第三个式子的作用是将mixup的生成图片与源域对齐。

个人认为,这三个式子的目的是逐渐将目标域的图片,mixup​的图片与源域的图片对齐。本文中的判别器 D D D和之前的判别器不同。这个判别器包括特征提取功能和判别功能(否则直接将图片输入判别器,由于特征是low-level的特征,很难进行判别和对齐)。代码中的 D D D的网络是卷积层+全连接层+sigmoid层实现。

综合一下上述的损失函数可以发现, N e , N d N_e,N_d Ne,Nd的目的是对图片解码编码后,生成的图片和源域图片对齐。这部分损失函数并没有使用到 x t x^t xt x m x^m xm。后续的soft labeltriplet loss将会用到。

soft label 损失为

min ⁡ D L s o f t m = − E x s ∼ P s , x t ∼ P t l d o m m l o g ( D d o m ( x m ) ) + ( 1 − l d o m m ) l o g ( 1 − D d o m ( x m ) ) \min\limits_D{L^m_{soft}}=-E_{x^s\sim P_s,x^t\sim P_t}l_{dom}^mlog(D_{dom}(x^m))+(1-l_{dom}^m)log(1-D_{dom}(x^m)) DminLsoftm=ExsPs,xtPtldommlog(Ddom(xm))+(1ldomm)log(1Ddom(xm))

其中 l d o m m l_{dom}^m ldomm表示mixup图像的领域标签,即 λ \lambda λ

这个的作用是希望 D D D对于mixup图像,可以输出其领域标签信息为 λ \lambda λ

triplet loss

triplet loss中包含三类样本, ( a , p , n ) (a,p,n) (a,p,n),分别表示取出的样本,同类的样本和非同类的样本。本文中的triplet loss并不是针对类别层面,而是针对领域层面。

如果mixup的样本的 λ ≥ 0.5 \lambda\geq 0.5 λ0.5,说明这类样本更接近源域,那么 ( a , p , n ) = ( x m , x s , x t ) (a,p,n)=(x^m,x^s,x^t) (a,p,n)=(xm,xs,xt)

如果mixup的样本的 λ < 0.5 \lambda < 0.5 λ<0.5,说明这类样本更接近目标域,那么 ( a , p , n ) = ( x m , x t , x s ) (a,p,n)=(x^m,x^t,x^s) (a,p,n)=(xm,xt,xs)

之后计算triplet losstriplet loss中的偏置设定为 ∣ 2 λ − 1 ∣ |2\lambda-1| 2λ1

(代码中的 λ \lambda λ按照 β \beta β分布随机生成,但如果 λ \lambda λ的值比较靠近 0.5 0.5 0.5,就会修改到稍微远离 0.5 0.5 0.5

和之前很多方法类似,我们让判别器 D D D拥有分类的能力,这里分类针对的是源域和目标域数据的生成图像。不同的是这里的判别器 D D D包括特征提取功能。所以我们只需要在卷积层后加入全连接层用于分类。

min ⁡ N e , N d , D L c l s s + L c l s t \min\limits_{N_e,N_d,D}L_{cls}^s+L_{cls}^t Ne,Nd,DminLclss+Lclst

都是交叉熵损失,目标域的标签使用分类器 C C C给出的伪标签。

最后还有个分类器 C C C的优化损失函数,分类损失 min ⁡ N e , C L C \min\limits_{N_e,C}L_C Ne,CminLC

文章效果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

总结

本文的创新点在于引入mixup图像的方法,让源域和目标域的对齐不那么直接,而是通过mixup作为桥梁连接源域和目标域的对齐过程。

但这篇文章中并没有提到对齐条件概率分布(不知道是不是弄漏了)。没有在类别层面的对齐,效果依然很好,很奇怪,如果说给 D D D加入分类效果,其中用上了 C C C给的伪标签,这部分的对齐效果有这么好么?

在本文的主体部分,mixup图像以及生产图像的对齐上,并没有对齐条件概率分布的对齐。

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值