AttnGAN
摘要
-
本文提出了一种注意力生成对抗网络(AttnGAN),该网络通过注意力驱动的多阶段细化来实现细粒度的文本到图像生成。借助注意力生成网络,AttnGAN可以通过关注自然语言描述中的相关单词来合成图像不同子区域的细粒度细节。
-
此外,提出了一种注意力集中的多模态相似度模型,以计算出细粒度的图像-文本匹配损失,以训练生成器。
AttenGAN明显优于现有的技术,在CUB数据集上比之前的最优结果提升14.14%,在更具挑战性的COCO数据集上提高了170.25%。
- 还可以通过可视化AttnGAN的注意层来执行详细分析。它首次显示出分层注意力GAN能够在单词级别自动选择条件以生成图像的不同部分。
AttnGAN 结构
DAMSM为生成网络提供了细粒度的图文匹配损失。
Attentional Generative Network
多阶段图像生成:
- 将文本embedding得到word features和sentence features;
- 利用sentence features生成一个低分辨率的图像;
- 在低分辨率的图像上进一步加入word features和sentence features来生成更高分辨率的图像。
低分辨率图像生成
h 0 = F 0 ( z , F c a ( e ˉ ) ) h_{0}=F_{0}\left(z, F^{c a}(\bar{e})\right) h0=F0(z,Fca(eˉ))
z z z是从标准正态分布中采样的随机噪声;
e ˉ \bar e eˉ是全局句子矢量;
同大多数Text2Image模型一样,AttnGAN也是利用随机噪声和全局句子矢量生成图像。
细粒度图像生成
在接下来的 i i i部分内,每一部分都利用AttnGAN来为图像增加细节部分。
h i = F i ( h i − 1 , F i a t t n ( e , h i − 1 ) ) f o r i = 1 , 2 , … , m − 1 h_{i}=F_{i}\left(h_{i-1}, F_{i}^{a t t n}\left(e, h_{i-1}\right)\right) \quad \text for \quad i=1,2, \ldots, m-1 hi=Fi(hi−1,Fiattn(e,hi−1))fori=1,2,…,m−1
x ^ i = G i ( h i ) \hat{x}_{i}=G_{i}\left(h_{i}\right) x^i=Gi(hi)
对输入图片的每一部分,匹配最相关的单词向量来约束其生成,增加图像的细粒度细节。匹配图像子区域和最相关的单词公式如下:
c
j
=
∑
i
=
0
T
−
1
β
j
,
i
e
i
′
,
w
h
e
r
e
β
j
,
i
=
exp
(
s
j
,
i
′
)
∑
k
=
0
T
−
1
exp
(
s
j
,
k
′
)
c_{j}=\sum_{i=0}^{T-1} \beta_{j, i} e_{i}^{\prime}, \quad \text where \quad \beta_{j, i}=\frac{\exp \left(s_{j, i}^{\prime}\right)}{\sum_{k=0}^{T-1} \exp \left(s_{j, k}^{\prime}\right)}
cj=i=0∑T−1βj,iei′,whereβj,i=∑k=0T−1exp(sj,k′)exp(sj,i′)
s j , i ′ = h j T e i ′ s_{j, i}^{\prime}=h_{j}^{T} e_{i}^{\prime} sj,i′=hjTei′
Multi-modal loss
为了生成具有多个级别(即句子级别和单词级别)条件的真实图像,注意力生成网络的最终目标功能定义为:
L
=
L
G
+
λ
L
D
A
M
S
M
,
w
h
e
r
e
L
G
=
∑
i
=
0
m
−
1
L
G
i
\mathcal{L}=\mathcal{L}_{G}+\lambda \mathcal{L}_{D A M S M}, \quad \text where \mathcal{L}_{G}=\sum_{i=0}^{m-1} \mathcal{L}_{G_{i}}
L=LG+λLDAMSM,whereLG=i=0∑m−1LGi
λ \lambda λ是一个超参数,可以平衡方程的两个项;
第一项是GAN损失:
L G i = − 1 2 E x ^ i ∼ p G i [ log ( D i ( x ^ i ) ] ⏟ unconditional loss − 1 2 E x ^ i ∼ p G i [ log ( D i ( x ^ i , e ˉ ) ] ⏟ conditional loss , \mathcal{L}_{G_{i}}=\underbrace{-\frac{1}{2} \mathbb{E}_{\hat{x}_{i} \sim p_{G_{i}}}\left[\log \left(D_{i}\left(\hat{x}_{i}\right)\right]\right.}_{\text {unconditional loss }} \underbrace{-\frac{1}{2} \mathbb{E}_{\hat{x}_{i} \sim p_{G_{i}}}\left[\log \left(D_{i}\left(\hat{x}_{i}, \bar{e}\right)\right]\right.}_{\text {conditional loss }}, LGi=unconditional loss −21Ex^i∼pGi[log(Di(x^i)]conditional loss −21Ex^i∼pGi[log(Di(x^i,eˉ)],
无条件损失确定图像是真实的还是伪造的,而条件损失确定图像和句子是否匹配。
第二项是DAMSM损失
此外,每个判别器 D i D_i Di也经过训练,通过最小化交叉熵损失来判断输入的真伪:
L D i = − 1 2 E x i ∼ p data i [ log D i ( x i ) ] − 1 2 E x ^ i ∼ p G i [ log ( 1 − D i ( x ^ i ) ] ⏟ unconditional loss + − 1 2 E x i ∼ p data i [ log D i ( x i , e ˉ ) ] − 1 2 E x ^ i ∼ p G i [ log ( 1 − D i ( x ^ i , e ˉ ) ] ⏟ conditional loss \begin{aligned} \mathcal{L}_{D_{i}}=& \underbrace{-\frac{1}{2} \mathbb{E}_{x_{i} \sim p_{\text {data }_{i}}}\left[\log D_{i}\left(x_{i}\right)\right]-\frac{1}{2} \mathbb{E}_{\hat{x}_{i} \sim p_{G_{i}}}\left[\log \left(1-D_{i}\left(\hat{x}_{i}\right)\right]\right.}_{\text {unconditional loss }}+\\ & \underbrace{-\frac{1}{2} \mathbb{E}_{x_{i} \sim p_{\text {data }_{i}}}\left[\log D_{i}\left(x_{i}, \bar{e}\right)\right]-\frac{1}{2} \mathbb{E}_{\hat{x}_{i} \sim p_{G_{i}}}\left[\log \left(1-D_{i}\left(\hat{x}_{i}, \bar{e}\right)\right]\right.}_{\text {conditional loss }} \end{aligned} LDi=unconditional loss −21Exi∼pdata i[logDi(xi)]−21Ex^i∼pGi[log(1−Di(x^i)]+conditional loss −21Exi∼pdata i[logDi(xi,eˉ)]−21Ex^i∼pGi[log(1−Di(x^i,eˉ)]
其中, x i x_i xi是来自真实图片, x ^ i \hat x_i x^i是模型生成的数据,AttnGAN的鉴别器在结构上是不相交的,因此可以并行训练它们。
Deep Attentional Multimodal Similarity Model
DAMSM包含了两个神经网络,它们将图像的子区域和句子的单词映射到公共语义空间,从而度量图像-文本相似度,从而计算图像生成的细粒度损失。
文本编码器
是双向的LSTM,用来从文本描述中提取语义向量,
图像编码器
图像编码器是将图像映射到语义向量的卷积神经网络(CNN)。CNN的中间层学习图像的不同子区域的局部特征,而后面的层学习图像的全局特征。
The attention-driven image-text matching score
用于基于图像和文本之间的注意力模型来衡量图像句子对的匹配
-
首先为图像中句子和子区域中所有可能的单词对计算相似度矩阵:
s = e T v s=e^{T} v s=eTv -
归一化处理:
s ˉ i , j = exp ( s i , j ) ∑ k = 0 T − 1 exp ( s k , j ) \bar{s}_{i, j}=\frac{\exp \left(s_{i, j}\right)}{\sum_{k=0}^{T-1} \exp \left(s_{k, j}\right)} sˉi,j=∑k=0T−1exp(sk,j)exp(si,j)
- 建立一个注意力模型来计算每个单词(查询)的区域上下文向量
c i = ∑ j = 0 288 α j v j , w h e r e α j = exp ( γ 1 s ˉ i , j ) ∑ k = 0 288 exp ( γ 1 s ˉ i , k ) c_{i}=\sum_{j=0}^{288} \alpha_{j} v_{j}, \quad \text where \quad \alpha_{j}=\frac{\exp \left(\gamma_{1} \bar{s}_{i, j}\right)}{\sum_{k=0}^{288} \exp \left(\gamma_{1} \bar{s}_{i, k}\right)} ci=j=0∑288αjvj,whereαj=∑k=0288exp(γ1sˉi,k)exp(γ1sˉi,j)
- 使用余弦相似度来定义单词和图像之间的相关性
R ( c i , e i ) = ( c i T e i ) / ( ∥ c i ∥ ∥ e i ∥ ) R\left(c_{i}, e_{i}\right)=\left(c_{i}^{T} e_{i}\right) /\left(\left\|c_{i}\right\|\left\|e_{i}\right\|\right) R(ci,ei)=(ciTei)/(∥ci∥∥ei∥)
R ( Q , D ) = log ( ∑ i = 1 T − 1 exp ( γ 2 R ( c i , e i ) ) ) 1 γ 2 R(Q, D)=\log \left(\sum_{i=1}^{T-1} \exp \left(\gamma_{2} R\left(c_{i}, e_{i}\right)\right)\right)^{\frac{1}{\gamma_{2}}} R(Q,D)=log(i=1∑T−1exp(γ2R(ci,ei)))γ21
The DAMSM loss
被设计为以半监督方式学习注意力模型,其中唯一的监督是整个图像和整体(单词序列)之间的匹配。
对图像-句子对
{
(
Q
i
,
D
i
)
}
i
=
1
M
\left\{\left(Q_{i}, D_{i}\right)\right\}_{i=1}^{M}
{(Qi,Di)}i=1M,句子
D
i
D_i
Di和图像
Q
i
Q_i
Qi匹配的后验概率被计算为:
P
(
D
i
∣
Q
i
)
=
exp
(
γ
3
R
(
Q
i
,
D
i
)
)
∑
j
=
1
M
exp
(
γ
3
R
(
Q
i
,
D
j
)
)
P\left(D_{i} \mid Q_{i}\right)=\frac{\exp \left(\gamma_{3} R\left(Q_{i}, D_{i}\right)\right)}{\sum_{j=1}^{M} \exp \left(\gamma_{3} R\left(Q_{i}, D_{j}\right)\right)}
P(Di∣Qi)=∑j=1Mexp(γ3R(Qi,Dj))exp(γ3R(Qi,Di))
将损失函数定义为图像与其对应的文本描述(groundtruth)匹配的负对数后验概率
L
1
w
=
−
∑
i
=
1
M
log
P
(
D
i
∣
Q
i
)
\mathcal{L}_{1}^{w}=-\sum_{i=1}^{M} \log P\left(D_{i} \mid Q_{i}\right)
L1w=−i=1∑MlogP(Di∣Qi)
最后,DAMSM损失定义为:
L
D
A
M
S
M
=
L
1
w
+
L
2
w
+
L
1
s
+
L
2
s
\mathcal{L}_{D A M S M}=\mathcal{L}_{1}^{w}+\mathcal{L}_{2}^{w}+\mathcal{L}_{1}^{s}+\mathcal{L}_{2}^{s}
LDAMSM=L1w+L2w+L1s+L2s
数据集
本次实验使用的数据集是加利福尼亚理工学院鸟类薮据库2011(CUB_2002011)
实验结果
源自论文: