3. Proposed Method
3.1 Limitation of Skip Connections in AttGAN
StarGAN and AttGAN adopt an encoder-decoder structure, where spatial pooling or downsampling is essential to obtain a high-level abstract representation for attribute manipulation. Unfortunately, downsampling irreversibly diminishes the spatial resolution and fine details of the feature maps, which cannot be fully recovered by transposed convolutions, so the results are prone to blurring or missing details. In short, the downsampling in the encoder-decoder structure loses spatial detail that deconvolution cannot recover, and the generated images therefore tend to be blurry.
AttGAN adds a skip connection between the encoder and decoder, but its effect is still limited. The authors do not analyze this theoretically; instead, they demonstrate the limitation of skip connections experimentally.
Consider four variants of AttGAN:
- AttGAN-ED: no skip connection
- AttGAN: the official version, with 1 skip connection
- AttGAN-2s: 2 skip connections
- AttGAN-UNet: skip connections at every layer, equivalent to a U-Net
On a face image dataset, setting the target attribute vector equal to the source attribute vector turns editing into a face reconstruction task. Table 1 lists two reconstruction metrics (PSNR/SSIM) and Figure 3 shows the reconstructed faces; skip connections indeed improve reconstruction quality.
For a second task, a classifier for 13 attributes is first trained on CelebA, reaching an average accuracy of 94.5%. Images with a new attribute are then generated and fed to this classifier to check whether the added attribute is recognized, which yields the attribute generation accuracy.
Figure 3 also shows the attribute generation accuracy of the four models: the more skip connections are added, the lower the attribute generation accuracy.
Putting these results together: adding skip connections does improve reconstruction quality, but weakens the ability to generate new attributes.
3.2 Taking Difference Attribute Vector as Input
Define $\text{att}_s$ as the source attribute vector and $\text{att}_t$ as the target attribute vector, and consider only their difference:

$$\text{att}_{diff} = \text{att}_t - \text{att}_s \qquad(1)$$

Taking only the difference between the source and target attribute vectors has three benefits:
- The difference is a simpler representation, making the network easier to train.
- The difference encodes which attributes need (or need not) to be edited, i.e., the direction of each attribute edit.
- The difference is easier for users to provide: a user only needs to specify which attributes to change and in which direction.
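The difference vector of Eq. (1) is straightforward to compute. A minimal sketch with hypothetical binary attribute vectors (the attribute names and ordering below are illustrative, not from the paper):

```python
import numpy as np

# Hypothetical binary attribute vectors over 5 attributes
# (names/order are illustrative only).
att_s = np.array([1, 0, 0, 1, 0])  # source, e.g. [male, blond, glasses, young, smiling]
att_t = np.array([1, 0, 1, 1, 0])  # target: add "glasses", keep everything else

# Eq. (1): the difference encodes which attributes to edit and in which direction:
#   +1 = add the attribute, -1 = remove it, 0 = leave unchanged.
att_diff = att_t - att_s
print(att_diff)  # -> [0 0 1 0 0]
```

Note how a user only has to flip the entries they care about; all untouched attributes automatically map to 0.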
3.3 Selective Transfer Units
作者提出一种更高级的skip connection,称为Selective Transfer Units,模型框架图如Figure 5所示
STU是在GRU的基础上进行改进
公式(2)~(7)
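Since Equations (2)~(7) are not reproduced in these notes, the following is only a GRU-style gating sketch of the idea, not the paper's exact STU: the hidden state passed up from the deeper layer and $\text{att}_{diff}$ jointly gate how much of the encoder feature is transferred. All weight shapes and the vector (rather than convolutional) form are assumptions for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def stu_sketch(f_enc, s_next, att_diff, rng):
    """GRU-style gating sketch of an STU (an assumption, not the paper's
    exact Eqs. (2)-(7)): reset/update gates conditioned on the deeper
    hidden state s_next and on att_diff select what to transfer."""
    d = f_enc.shape[0]
    n_in = 2 * d + att_diff.shape[0]
    # Hypothetical random weights; a real STU learns these (as convolutions).
    W_r = rng.standard_normal((d, n_in)) * 0.1
    W_z = rng.standard_normal((d, n_in)) * 0.1
    W_h = rng.standard_normal((d, 2 * d)) * 0.1
    inp = np.concatenate([f_enc, s_next, att_diff])
    r = sigmoid(W_r @ inp)                                   # reset gate
    z = sigmoid(W_z @ inp)                                   # update gate
    h = np.tanh(W_h @ np.concatenate([f_enc, r * s_next]))   # candidate state
    s = (1 - z) * s_next + z * h                             # new hidden state
    f_t = np.tanh(s)                                         # transformed feature
    return f_t, s

rng = np.random.default_rng(0)
f_t, s = stu_sketch(rng.standard_normal(8), rng.standard_normal(8),
                    np.array([0.0, 1.0, -1.0]), rng)
print(f_t.shape, s.shape)  # -> (8,) (8,)
```

The key design point carried over from the GRU is that the transferred feature is gated, rather than copied verbatim as in a plain skip connection.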
3.4 Network Architecture
STGAN consists of two networks: a generator $G$ and a discriminator $D$.
$G$ comprises an encoder $G_{enc}$ and a decoder $G_{dec}$; $D$ comprises a discriminator network $D_{adv}$ and an attribute classification network $D_{att}$.
$G_{enc}$ contains 5 convolutional layers (kernel_size=4, stride=2), which is why five cubes follow the input image in Figure 5.
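With kernel_size=4 and stride=2 (padding=1 assumed, as is standard for this setup), each encoder layer halves the spatial resolution. A quick check using the standard convolution output-size formula (the 128×128 input resolution is an assumption) shows why five feature maps appear after the input image:

```python
def conv_out(size, kernel=4, stride=2, pad=1):
    """Output spatial size of one conv layer: floor((n + 2p - k)/s) + 1."""
    return (size + 2 * pad - kernel) // stride + 1

size = 128  # assumed input resolution
sizes = []
for layer in range(5):  # the 5 conv layers of G_enc
    size = conv_out(size)
    sizes.append(size)
print(sizes)  # -> [64, 32, 16, 8, 4]
```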
3.5 Loss Functions
$$\mathbf{f}=G_{enc}(\mathbf{x}) \qquad(8)$$

where $\mathbf{f}=\left\{ \mathbf{f}_{enc}^1, \cdots, \mathbf{f}_{enc}^5 \right\}$.
The four STU units perform the following computation:
$$\left( \mathbf{f}_t^l, s^l \right)=G_{st}^l \left( \mathbf{f}_{enc}^l, s^{l+1}, \mathbf{att}_{diff} \right) \qquad(9)$$
The above can be understood as follows: $\mathbf{f}=\left\{ \mathbf{f}_{enc}^1, \cdots, \mathbf{f}_{enc}^5 \right\}$ is transformed by the STUs into $\mathbf{f}_t=\left\{ \mathbf{f}_t^1, \cdots, \mathbf{f}_t^4 \right\}$, after which $\mathbf{f}_t$ and $\mathbf{f}_{enc}^5$ are fed into $G_{dec}$ to generate the image $\hat{\mathbf{y}}$, i.e.,

$$\hat{\mathbf{y}}=G_{dec}\left( \mathbf{f}_{enc}^5, \mathbf{f}_t \right) \qquad(10)$$
Combining Equations (8)~(10) gives

$$\hat{\mathbf{y}}=G\left( \mathbf{x}, \mathbf{att}_{diff} \right) \qquad(11)$$
Reconstruction loss
When $\mathbf{att}_{diff}=\mathbf{0}$, the generated (reconstructed) image should equal the input image, so the reconstruction loss is defined as

$$\mathcal{L}_{rec}=\left\| \mathbf{x} - G(\mathbf{x}, \mathbf{0}) \right\|_1 \qquad(12)$$
The $\ell_1$-norm $\left\| \cdot \right\|_1$ is used to preserve the sharpness of the reconstructed image.
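Eq. (12) is just the absolute pixel-wise difference between the input and its reconstruction. A minimal numpy sketch (the generator is stubbed out as identity-plus-noise purely for illustration; the mean reduction and the image shape are assumptions):

```python
import numpy as np

def l1_loss(x, y):
    """Eq. (12): l1 distance between images, averaged over pixels here."""
    return np.abs(x - y).mean()

rng = np.random.default_rng(0)
x = rng.random((3, 64, 64))                       # input image (assumed shape)
recon = x + 0.01 * rng.standard_normal(x.shape)   # stand-in for G(x, 0)

L_rec = l1_loss(x, recon)
print(float(L_rec))  # small, since the stub barely perturbs x
```

An $\ell_2$ loss in the same position would average out high-frequency errors and encourage blur, which is why $\ell_1$ is the common choice here.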
Adversarial loss
When $\mathbf{att}_{diff}\neq\mathbf{0}$, the ground truth of the generated image is unknown, so only an adversarial loss can be used.
The WGAN-GP form of the adversarial loss is used; the losses for $D_{adv}$ and $G$ are defined as
$$\begin{aligned} \underset{D_{adv}}{\max}\ \mathcal{L}_{D_{adv}} =\ &\mathbb{E}_\mathbf{x}D_{adv}(\mathbf{x})-\mathbb{E}_\mathbf{\hat{y}}D_{adv}(\mathbf{\hat{y}})\ + \\ &\lambda\,\mathbb{E}_\mathbf{\hat{x}}\left( \left\| \nabla_\mathbf{\hat{x}}D_{adv}\left( \mathbf{\hat{x}} \right) \right\|_2 - 1 \right)^2 \end{aligned} \qquad(13)$$
$$\underset{G}{\max}\ \mathcal{L}_{G_{adv}}=\mathbb{E}_{\mathbf{x},\mathbf{att}_{diff}}D_{adv}\left( G\left( \mathbf{x},\mathbf{att}_{diff} \right) \right) \qquad(14)$$
where $\hat{\mathbf{x}}$ is sampled along lines between pairs of real and generated images.
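The sampling of $\hat{\mathbf{x}}$ in Eq. (13) is the standard WGAN-GP interpolation: for each real/generated pair, draw a random $\alpha$ and take a point on the segment between them. A numpy sketch of just this sampling step (batch and image shapes are assumptions; the gradient norm itself would come from autodiff in practice):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((4, 3, 64, 64))      # batch of real images (assumed shape)
y_hat = rng.random((4, 3, 64, 64))  # batch of generated images

# One alpha per sample, broadcast over channels and pixels:
# each x_hat lies on the line between its real/generated pair.
alpha = rng.random((4, 1, 1, 1))
x_hat = alpha * x + (1 - alpha) * y_hat

# Sanity check: every interpolate stays between its endpoints.
assert np.all(x_hat <= np.maximum(x, y_hat) + 1e-12)
assert np.all(x_hat >= np.minimum(x, y_hat) - 1e-12)
print(x_hat.shape)
```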
Attribute manipulation loss
An attribute classifier $D_{att}$ is introduced, sharing its convolutional layers with $D_{adv}$.
The attribute manipulation losses for $D_{att}$ and $G$ are defined as
$$\begin{aligned} \mathcal{L}_{D_{att}}=-\sum_{i=1}^{c}\Big[ &\mathbf{att}_s^{(i)}\log D_{att}^{(i)}(\mathbf{x})\ + \\ &\left( 1-\mathbf{att}_s^{(i)} \right)\log\left( 1-D_{att}^{(i)}(\mathbf{x}) \right) \Big] \end{aligned} \qquad(15)$$
$$\begin{aligned} \mathcal{L}_{G_{att}}=-\sum_{i=1}^{c}\Big[ &\mathbf{att}_t^{(i)}\log D_{att}^{(i)}(\mathbf{\hat{y}})\ + \\ &\left( 1-\mathbf{att}_t^{(i)} \right)\log\left( 1-D_{att}^{(i)}(\mathbf{\hat{y}}) \right) \Big] \end{aligned} \qquad(16)$$
The superscript $^{(i)}$ denotes the $i$-th attribute component; there are $c$ attributes in total.
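Eqs. (15)~(16) are a sum of $c$ independent binary cross-entropies, one per attribute. A numpy sketch with hypothetical classifier outputs (the label and probability values are made up for illustration):

```python
import numpy as np

def attribute_bce(att, probs, eps=1e-12):
    """Sum of per-attribute binary cross-entropies, as in Eqs. (15)/(16)."""
    probs = np.clip(probs, eps, 1 - eps)  # guard against log(0)
    return -np.sum(att * np.log(probs) + (1 - att) * np.log(1 - probs))

att_t = np.array([1.0, 0.0, 1.0])     # target labels for c = 3 attributes
probs = np.array([0.9, 0.2, 0.8])     # hypothetical D_att outputs on y_hat

L_G_att = attribute_bce(att_t, probs)
print(round(float(L_G_att), 4))  # -> 0.5516
```

For Eq. (15) the same function would be called with $\text{att}_s$ and the classifier outputs on the real image $\mathbf{x}$.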
Model Objective
The objective functions for $D$ and $G$ are
$$\underset{D}{\min}\ \mathcal{L}_D=-\mathcal{L}_{D_{adv}}+\lambda_1\mathcal{L}_{D_{att}} \qquad(17)$$
$$\underset{G}{\min}\ \mathcal{L}_G=-\mathcal{L}_{G_{adv}}+\lambda_2\mathcal{L}_{G_{att}}+\lambda_3\mathcal{L}_{rec} \qquad(18)$$

Note that the generator objective uses its own terms $\mathcal{L}_{G_{adv}}$ and $\mathcal{L}_{G_{att}}$ from Eqs. (14) and (16).
In the experiments, $\lambda_1=1$, $\lambda_2=10$, and $\lambda_3=100$.
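Plugging these weights into Eqs. (17)~(18) is simple arithmetic; a toy check with entirely made-up loss values (chosen only to show the weighting, not taken from any experiment):

```python
# Weights reported in the experiments.
lam1, lam2, lam3 = 1.0, 10.0, 100.0

# Hypothetical loss values, for illustration only.
L_D_adv, L_D_att = 0.5, 0.7
L_G_adv, L_G_att, L_rec = -0.3, 0.6, 0.02

# Eq. (17): discriminator objective (minimized).
L_D = -L_D_adv + lam1 * L_D_att
# Eq. (18): generator objective (minimized).
L_G = -L_G_adv + lam2 * L_G_att + lam3 * L_rec

print(L_D, L_G)
```

The large $\lambda_3$ reflects that the per-pixel $\ell_1$ term is numerically tiny compared with the adversarial and attribute terms.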