STGAN: A Unified Selective Transfer Network for Arbitrary Image Attribute Editing(CVPR19)

3. Proposed Method

3.1 Limitation of Skip Connections in AttGAN

StarGAN and AttGAN adopt encoder-decoder structure, where spatial pooling or downsampling are essential to obtain high level abstract representation for attribute manipulation. Unfortunately, downsampling irreversibly diminishes spatial resolution and fine details of feature map, which cannot be completely recovered by transposed convolutions and the results are prone to blurring or missing details.

StarGAN和AttGAN使用encoder-decoder结构,其中downsampling操作会损失空间上的细节信息,并且无法通过反卷积来恢复,因此生成图像往往会模糊

AttGAN在encoder和decoder之间增加skip connection,但作用仍然有限,作者并没有从理论上进行分析,而是通过实验验证了skip connection的局限性

考虑AttGAN的4个版本

  1. AttGAN-ED:不使用skip connection
  2. AttGAN:官方版本,使用1个skip connection
  3. AttGAN-2s:使用2个skip connection
  4. AttGAN-UNet:所有层都使用skip connection,相当于UNet

在某个人脸图像数据集上,令target attribute vector等于source attribute vector,进行人脸重构的任务,Table 1列举了重构的2个指标(PSNR/SSIM),Figure 3展示了重构的结果,可以看到skip connection确实使得人脸重构的效果变好了
在这里插入图片描述
在这里插入图片描述
现在进行另一个任务,首先在CelebA数据集上训练了一个识别13种attribute的分类器,平均正确率为94.5%,然后生成带有新的attribute的图像,将图像交给attribute分类器去识别,看能不能识别出新加的attribute,从而计算出attribute generate accuracy
在这里插入图片描述
Figure 3展示了4个模型的attribute generation accuracy,可以看到skip connection加得越多,attribute generation accuracy越低

综合以上的结论,增加skip connection,重构的图像质量确实会变好,但生成attribure的能力却变差了

3.2 Taking Difference Attribute Vector as Input

定义 att s \text{att}_s atts为source attribute vector, att t \text{att}_t attt为target attribute vector

仅考虑source attribute vector和target attribute vector之间的差值,有3点好处
att d i f f = att t − att s ( 1 ) \text{att}_{diff} = \text{att}_t - \text{att}_s \qquad(1) attdiff=atttatts(1)

  1. 差值表示更简单,使得网络更容易训练
  2. 差值包含了哪些attribute需要/不需要编辑,attribute编辑的direction信息
  3. 差值更容易被用户提供,用户只需要指定改变哪些属性,以及改变的方向即可
3.3 Selective Transfer Units

作者提出一种更高级的skip connection,称为Selective Transfer Units,模型框架图如Figure 5所示
在这里插入图片描述
STU是在GRU的基础上进行改进
公式(2)~(7)

3.4 Network Architecture

STGAN包含2个网络:generator G G G、discriminator D D D

其中 G G G包括encoder G e n c G_{enc} Genc和decoder G d e c G_{dec} Gdec D D D包括判别网络 D a d v D_{adv} Dadv和属性分类网络 D a t t D_{att} Datt

G e n c G_{enc} Genc包含5层卷积层(kernel_size=4,stride=2),因此Figure 5中输入图像之后共有5个立方体

3.5 Loss Functions

f = G e n c ( x ) ( 8 ) \mathbf{f}=G_{enc}(\mathbf{x}) \qquad(8) f=Genc(x)(8)
其中 f = { f e n c 1 , ⋯   , f e n c 5 } \mathbf{f}=\left \{ \mathbf{f}_{enc}^1, \cdots, \mathbf{f}_{enc}^5 \right \} f={fenc1,,fenc5}

4个STU单元的运算如下
( f t l , s l ) = G s t l ( f e n c l , s l + 1 , a t t d i f f ) ( 9 ) \left ( \mathbf{f}_t^l, s^l \right )=G_{st}^l \left ( \mathbf{f}_{enc}^l, s^{l+1}, \mathbf{att}_{diff} \right ) \qquad(9) (ftl,sl)=Gstl(fencl,sl+1,attdiff)(9)

上述公式可以理解为 f = { f e n c 1 , ⋯   , f e n c 5 } \mathbf{f}=\left \{ \mathbf{f}_{enc}^1, \cdots, \mathbf{f}_{enc}^5 \right \} f={fenc1,,fenc5}经过STU后变换为 f t = { f t 1 , ⋯   , f t 4 } \mathbf{f}_t=\left \{ \mathbf{f}_t^1, \cdots, \mathbf{f}_t^4 \right \} ft={ft1,,ft4},再将 f t \mathbf{f}_t ft f e n c 5 \mathbf{f}_{enc}^5 fenc5送入 G d e c G_{dec} Gdec中用于生成图像 y ^ \hat{\mathbf{y}} y^,即
y ^ = G d e c ( f e n c 5 , f t ) ( 10 ) \hat{\mathbf{y}}=G_{dec}\left ( \mathbf{f}_{enc}^5, \mathbf{f}_t \right ) \qquad(10) y^=Gdec(fenc5,ft)(10)
综合公式(8)~(10),有
y ^ = G ( x , a t t d i f f ) ( 11 ) \hat{\mathbf{y}}=G\left ( \mathbf{x}, \mathbf{att}_{diff} \right ) \qquad(11) y^=G(x,attdiff)(11)

Reconstruction loss

a t t d i f f = 0 \mathbf{att}_{diff}=\mathbf{0} attdiff=0时,生成图像(重构图像)应该与输入图像相等,于是定义reconstruction loss如下
L r e c = ∥ x − G ( x , 0 ) ∥ 1 ( 12 ) \mathcal{L}_{rec}=\left \| \mathbf{x} - G(\mathbf{x}, \mathbf{0}) \right \|_1 \qquad(12) Lrec=xG(x,0)1(12)
使用 ℓ 1 \ell_1 1-norm ∥ ⋅ ∥ 1 \left \| \cdot \right \|_1 1来保证重构图像的sharpness

Adversarial loss

a t t d i f f ≠ 0 \mathbf{att}_{diff}\neq\mathbf{0} attdiff=0时,生成图像的ground-truth未知,因此只能使用adversarial loss

本文使用WGAN-GP版本的adversarial loss,分别定义 D a d v D_{adv} Dadv G G G的loss如下
max ⁡ D a d v   L D a d v = E x D a d v ( x ) − E y ^ D a d v ( y ^ ) + λ E x ^ ( ∥ ∇ x ^ D a d v ( x ^ ) ∥ 2 − 1 ) 2 ( 13 ) \begin{aligned} \underset{D_{adv}}{\max}\ \mathcal{L}_{D_{adv}} =&\mathbb{E}_\mathbf{x}D_{adv}(\mathbf{x})-\mathbb{E}_\mathbf{\hat{y}}D_{adv}(\mathbf{\hat{y}}) +\\ &\lambda\mathbb{E}_\mathbf{\hat{x}}\left ( \left \| \nabla_\mathbf{\hat{x}}D_{adv}\left ( \mathbf{\hat{x}} \right ) \right \|_2 - 1 \right )^2 \qquad(13) \end{aligned} Dadvmax LDadv=ExDadv(x)Ey^Dadv(y^)+λEx^(x^Dadv(x^)21)2(13)
max ⁡ G   L G a d v = E x , a t t d i f f D a d v ( G ( x , a t t d i f f ) ) ( 14 ) \underset{G}{\max}\ \mathcal{L}_{G_{adv}}=\mathbb{E}_{\mathbf{x},\mathbf{att}_{diff}}D_{adv}\left ( G\left ( \mathbf{x},\mathbf{att}_{diff} \right ) \right ) \qquad(14) Gmax LGadv=Ex,attdiffDadv(G(x,attdiff))(14)
其中 x ^ \hat{\mathbf{x}} x^ is sampled along lines between pairs of real and generated images

Attribute manipulation loss

引入一个attribute classifier D a t t D_{att} Datt,与 D a d v D_{adv} Dadv共享卷积部分的layer
分别定义 D a d v D_{adv} Dadv G G Gattribute manipulation loss如下
L D a t t = − ∑ i = 1 c [ a t t s ( i ) log ⁡ D a t t ( i ) ( x ) + ( 1 − a t t s ( i ) ) log ⁡ ( 1 − D a t t ( i ) ( x ) ) ] ( 15 ) \begin{aligned} \mathcal{L}_{D_{att}}=-\sum_{i=1}^{c}\Big [&\mathbf{att}_s^{(i)}\log D_{att}^{(i)}(\mathbf{x})+\\ &\left ( 1-\mathbf{att}_s^{(i)} \right )\log\left ( 1-D_{att}^{(i)}(\mathbf{x}) \right ) \Big ] \qquad(15) \end{aligned} LDatt=i=1c[atts(i)logDatt(i)(x)+(1atts(i))log(1Datt(i)(x))](15)
L G a t t = − ∑ i = 1 c [ a t t t ( i ) log ⁡ D a t t ( i ) ( y ^ ) + ( 1 − a t t t ( i ) ) log ⁡ ( 1 − D a t t ( i ) ( y ^ ) ) ] ( 16 ) \begin{aligned} \mathcal{L}_{G_{att}}=-\sum_{i=1}^{c}\Big [&\mathbf{att}_t^{(i)}\log D_{att}^{(i)}(\mathbf{\hat{y}})+\\ &\left ( 1-\mathbf{att}_t^{(i)} \right )\log\left ( 1-D_{att}^{(i)}(\mathbf{\hat{y}}) \right ) \Big ] \qquad(16) \end{aligned} LGatt=i=1c[attt(i)logDatt(i)(y^)+(1attt(i))log(1Datt(i)(y^))](16)
上标 ( i ) ^{(i)} (i)表示属性的第 i i i个分量,共有 c c c个属性

Model Objective

D D D G G G的目标函数分别如下
min ⁡ D   L D = − L D a d v + λ 1 L D a t t ( 17 ) \underset{D}{\min}\ \mathcal{L}_D=-\mathcal{L}_{D_{adv}}+\lambda_1\mathcal{L}_{D_{att}} \qquad(17) Dmin LD=LDadv+λ1LDatt(17)
min ⁡ G   L G = − L D a d v + λ 2 L D a t t + λ 3 L r e c ( 18 ) \underset{G}{\min}\ \mathcal{L}_G=-\mathcal{L}_{D_{adv}}+\lambda_2\mathcal{L}_{D_{att}}+\lambda_3\mathcal{L}_{rec} \qquad(18) Gmin LG=LDadv+λ2LDatt+λ3Lrec(18)
实验中设置 λ 1 = 1 \lambda_1=1 λ1=1 λ 2 = 10 \lambda_2=10 λ2=10 λ 3 = 100 \lambda_3=100 λ3=100

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值