Adma-GAN是由浙大学者和腾讯联合提出的一种属性驱动内存增强型GAN,文章被 A类会议ACM Multimedia收录,发表于2022年9月28日。
原文地址:https://arxiv.org/abs/2209.14046
代码地址:https://github.com/Hsintien-Ng/Adma-GAN
本篇文章是阅读Adma-GAN: Attribute-Driven Memory Augmented GANs for Text-to-Image Generation的精读笔记。
一、原文摘要
作为一项具有挑战性的任务,文本到图像生成旨在根据给定的文本描述生成照片级真实感和语义一致的图像。现有的方法主要是从一个句子中提取文本信息来表示图像,文本表示对生成图像的质量影响很大。然而,在一句话中直接利用有限的信息会遗漏一些关键的属性描述,而这些属性描述是准确描述图像的关键因素。为了缓解上述问题,我们提出了一种有效的带有属性信息补充的文本表示方法。首先,我们构造了一个属性存储器来联合控制文本到图像的生成和句子输入。其次,我们探索了两种更新机制,样本感知和样本联合机制,以动态优化广义属性内存。此外,我们还设计了一种属性-句子联合条件生成器学习方案,以使特征嵌入在多个表示之间对齐,从而促进了跨模态网络的训练。实验结果表明,该方法在CUB(FID从14.81到8.57)和COCO(FID由21.42到12.39)数据集上都取得了显著的性能改进。
二、为什么提出Adma-GAN?
- 文本和图像数据之间的模态结构存在较大差距,跨模态文本到图像生成模型的优化容易出现过拟合或塌陷,从而生成不规则的对象形状。
- 现有的句子嵌入方法仅利用一个句子中有限的信息对文本表示进行建模,造成了以下障碍:它遗漏了一些关键属性描述,而这些关键属性描述是准确描述图像的关键因素。
基于此,文章提出了一种有效的基于句子属性信息补足的文本表示方法。并设计了一个属性存储器来联合句子和生成器的合成过程。因此,作者把文本到图像的生成任务看作是一个 属性-句子联合条件生成问题。问题的关键有两方面:
- 如何构造属性存储库;
- 如何学习具有属性和句子联合条件的图像生成器。
三、创新点
- 构造属性存储库,首先收集数据集中所有可能的属性描述作为属性库,并将它们转换为属性内存,然后提取属性的标签组合形成公共属性库,具体来说,作者构造了一个图来表示数据集中的属性相关性,并使用图卷积网络来提取属性特征,获得用于属性驱动条件生成的最佳属性内存。
- 设计了一种属性-句子联合条件生成器学习方案,用于处理多种表示(即句子、属性、图像)之间的转换,使用对比学习增强多个表示之间的语义一致性。在公共空间将图像与句子和属性对齐,属于同一样本的属性图像和句子图像对被拉得更近,而不同样本的对被推得更远。
四、属性驱动内存增强型GAN
4.1、模型结构
模型以DFGAN作为baseline,同样使用单阶段文本生成图像架构:
文本编码部分(图中未给出,4.2节展开)将原始句子转换为句子嵌入并构建一个属性存储器来自预定义属性库。
再主框架,包括一个存储增强型生成器和一个带有辅助分类的条件鉴别器,在生成器生成过程中,使用存储器提取文本的属性特征,将属性特征和句子特征分别插入到不同级别的Up-Block中。鉴别器鉴别过程与DFGAN类似。
4.2、属性内存更新机制
作者采用了两种属性内存更新策略:样本感知和样本联结,以从属性内存中获取当前句子最合适的属性嵌入(最终比较选择样本联结的策略)。
4.2.1、样本感知内存更新机制
为了更新属性内存,作者将内存的所有参数视为可优化参数,并将它们添加到整个生成器的参数组中。因此,可以通过网络的梯度反向传播来实现内存的更新。
如图中y表示,将图像样本用多属性二进制标签进行标注,1表示图像有此类属性,而0表示图像不具有此类属性。
给定一个样本,将图像样本标注的y与从文本编码中经过属性内存器提取的
M
a
M_a
Ma相乘,得到公共属性
e
a
=
y
⋅
M
a
e_{a}=y \cdot M_{a}
ea=y⋅Ma,其中y的维度为1×n,
M
a
M_{a}
Ma的维度为n×d,n表示所有属性的数量,d表示嵌入向量的维数。然后
e
a
e_{a}
ea作为附加条件与句子特征
e
s
e_{s}
es共同引导图像合成。
这样,通过梯度反向传播,当网络更新时,只优化与采样器相关的属性嵌入。但是,预定义属性库中的不同属性具有内在的关联性。这种方法忽略整个数据集中的全局相关模式。
4.2.2、样本联合内存更新机制
引入了基于相关矩阵的图神经网络GCN来建模属性之间的关系和传播信息。
在图神经网络中,属性内存被设置为图的初始节点特征,每个嵌入表示一个节点。给定初始节点特征为
H
0
=
M
a
H^0 =M_a
H0=Ma,相关矩阵为C,GCN通过堆叠可学习转换矩阵更新节点特征W. GCN层的表示为:
H
l
+
1
=
LeakyReLU
(
C
⋅
H
l
⋅
W
l
)
H^{l+1}=\operatorname{LeakyReLU}\left(C \cdot H^{l} \cdot W^{l}\right)
Hl+1=LeakyReLU(C⋅Hl⋅Wl),为了建模属性之间的全局相关性,作者通过计算训练集中属性对的出现次数构造了相关性矩阵C ,表示为
C
i
j
=
{
0
,
if
P
i
j
<
τ
1
,
if
P
i
j
≥
τ
C_{i j}=\left\{\begin{array}{ll} 0, & \text { if } P_{i j}<\tau \\ 1, & \text { if } P_{i j} \geq \tau \end{array}\right.
Cij={0,1, if Pij<τ if Pij≥τ,重新加权缓解二元相关矩阵的过光滑转换为:
C
i
j
′
=
{
p
∑
i
=
1
,
i
≠
j
n
C
i
j
,
if
i
≠
j
1
−
p
,
if
i
=
j
,
C_{i j}^{\prime}=\left\{\begin{array}{ll} \frac{p}{\sum_{i=1, i \neq j}^{n} C_{i j}}, & \text { if } i \neq j \\ 1-p, & \text { if } i=j \end{array},\right.
Cij′={∑i=1,i=jnCijp,1−p, if i=j if i=j,因此
M
a
M_a
Ma就可以通过基于相关矩阵C的GCN不断更新,
e
a
′
=
y
⋅
H
L
e_{a}^{\prime}=y \cdot H^{L}
ea′=y⋅HL,在更新当前样本的属性嵌入时,其他样本的共现属性嵌入也将得到优化,从而获得更有效的属性内存。
4.3、属性-图像对齐
作者引入了一种对比学习损失来对齐属性和图像嵌入到公共空间,形式上,作者采用余弦相似度作为度量标准:
L
c
l
(
u
,
v
)
=
−
1
m
∑
i
=
1
m
log
exp
(
cos
(
u
i
,
v
i
)
/
η
)
∑
j
=
1
m
exp
(
cos
(
u
i
,
v
j
)
/
η
)
\mathcal{L}_{c l}(u, v)=-\frac{1}{m} \sum_{i=1}^{m} \log \frac{\exp \left(\cos \left(u^{i}, v^{i}\right) / \eta\right)}{\sum_{j=1}^{m} \exp \left(\cos \left(u^{i}, v^{j}\right) / \eta\right)}
Lcl(u,v)=−m1∑i=1mlog∑j=1mexp(cos(ui,vj)/η)exp(cos(ui,vi)/η)
相应的计算为:
L
a
t
t
r
−
real
=
L
c
l
(
D
i
m
g
(
x
)
,
e
a
′
)
L
a
t
t
r
−
fake
=
L
c
l
(
D
i
m
g
(
x
f
)
,
e
a
′
)
\begin{aligned} \mathcal{L}_{a t t r_{-} \text {real }} &=\mathcal{L}_{c l}\left(D_{i m g}(x), e_{a}^{\prime}\right) \\ \mathcal{L}_{a t t r_{-} \text {fake }} &=\mathcal{L}_{c l}\left(D_{i m g}\left(x_{f}\right), e_{a}^{\prime}\right) \end{aligned}
Lattr−real Lattr−fake =Lcl(Dimg(x),ea′)=Lcl(Dimg(xf),ea′)
此外,作者还将对比学习同时应用于其他情态对,包括带句子的图像、带相同描述的真实图像的假图像:
L
sentreal
=
L
c
l
(
D
i
m
g
(
x
)
,
e
s
)
L
sentfake
=
L
c
l
(
D
i
m
g
(
x
f
)
,
e
s
)
L
img
=
L
c
l
(
D
i
m
g
(
x
)
,
D
i
m
g
(
x
f
)
)
\begin{array}{l} \mathcal{L}_{\text {sentreal }} = \mathcal{L}_{c l}\left(D_{i m g}(x), e_{s}\right)\\ \mathcal{L}_{\text {sentfake }}=\mathcal{L}_{c l}\left(D_{i m g}\left(x_{f}\right), e_{s}\right) \\ \mathcal{L}_{\text {img }}=\mathcal{L}_{c l}\left(D_{i m g}(x), D_{i m g}\left(x_{f}\right)\right) \end{array}
Lsentreal =Lcl(Dimg(x),es)Lsentfake =Lcl(Dimg(xf),es)Limg =Lcl(Dimg(x),Dimg(xf))
4.4、目标函数
作者通过三类约束来提高所提出的cGAN的能力:1)真实性鉴别;2) 多属性分类;3) 跨模态对齐。
总体损失为:
L
D
=
L
a
d
v
−
D
+
λ
1
L
align
D
+
λ
2
L
c
l
s
−
D
+
λ
3
L
m
a
−
g
p
,
L
G
=
L
a
d
v
−
G
+
λ
4
L
alignG
+
λ
5
L
c
l
s
−
G
\begin{array}{l} \mathcal{L}_{D}=\mathcal{L}_{a d v{-} D}+\lambda_{1} \mathcal{L}_{\text {align }} D+\lambda_{2} \mathcal{L}_{c l s{-} D}+\lambda_{3} \mathcal{L}_{m a-g p}, \\ \mathcal{L}_{G}=\mathcal{L}_{a d v{-} G}+\lambda_{4} \mathcal{L}_{\text {alignG }}+\lambda_{5} \mathcal{L}_{c l s{-} G} \end{array}
LD=Ladv−D+λ1Lalign D+λ2Lcls−D+λ3Lma−gp,LG=Ladv−G+λ4LalignG +λ5Lcls−G
4.4.1、真实性鉴别
这部分和DF-GAN相同,使用铰链损失作为对抗损失:
L
a
d
v
−
D
=
E
[
max
(
0
,
1
−
D
(
x
)
)
]
+
E
[
max
(
0
,
1
+
D
(
x
f
)
)
]
L
a
d
v
−
G
=
−
E
[
D
(
x
f
)
]
\begin{array}{l} \mathcal{L}_{a d v_{-} D}=\mathbb{E}[\max (0,1-D(x))]+\mathbb{E}\left[\max \left(0,1+D\left(x_{f}\right)\right)\right] \\ \mathcal{L}_{a d v_{-} G}=-\mathbb{E}\left[D\left(x_{f}\right)\right] \end{array}
Ladv−D=E[max(0,1−D(x))]+E[max(0,1+D(xf))]Ladv−G=−E[D(xf)]
4.4.2、多属性分类
将多属性分类设置为辅助任务,让鉴别器学习识别给定图像中的多个属性。为了消除两个不同任务的学习过程中的偏差,使分类器能够在分类属性标签时区分真假,
L
b
c
e
(
l
,
y
)
=
−
1
2
n
∑
i
=
1
2
n
(
y
i
log
(
l
i
)
+
(
1
−
y
i
)
log
(
1
−
l
i
)
)
L
c
l
s
−
D
=
L
b
c
e
(
l
r
,
y
r
)
+
L
b
c
e
(
l
f
,
y
f
)
,
L
c
l
−
G
=
L
b
c
e
(
l
f
,
y
r
)
−
L
b
c
e
(
l
f
,
y
f
)
,
\begin{aligned} \mathcal{L}_{b c e}(l, y)=-\frac{1}{2 n} & \sum_{i=1}^{2 n}\left(y^{i} \log \left(l^{i}\right)+\left(1-y^{i}\right) \log \left(1-l^{i}\right)\right) \\ \mathcal{L}_{c l s_{-} D} &=\mathcal{L}_{b c e}\left(l_{r}, y_{r}\right)+\mathcal{L}_{b c e}\left(l_{f}, y_{f}\right), \\ \mathcal{L}_{c l_{-} G} &=\mathcal{L}_{b c e}\left(l_{f}, y_{r}\right)-\mathcal{L}_{b c e}\left(l_{f}, y_{f}\right), \end{aligned}
Lbce(l,y)=−2n1Lcls−DLcl−Gi=1∑2n(yilog(li)+(1−yi)log(1−li))=Lbce(lr,yr)+Lbce(lf,yf),=Lbce(lf,yr)−Lbce(lf,yf),
4.4.3、跨模态对齐
作者合并了真实图像之间的所有对比度损失函数x 以及相应的文本嵌入{
e
s
e_s
es,
e
a
e_a
ea} 优化鉴别器 此外,伪图像之间的对比度损失函数
x
f
x_f
xf 和{
e
s
e_s
es,
e
a
e_a
ea} 用于规范生成器. 对应的对准损失函数:
L
alignD
=
L
attrreal
+
L
sentreal
L
alignG
=
L
attrfake
+
L
sentfake
+
L
img
\mathcal{L}_{\text {alignD}}=\mathcal{L}_{\text {attrreal }}+\mathcal{L}_{\text {sentreal }} \\ \mathcal{L}_{\text {alignG}}=\mathcal{L}_{\text {attrfake }}+\mathcal{L}_{\text {sentfake }}+\mathcal{L}_{\text {img }}
LalignD=Lattrreal +Lsentreal LalignG=Lattrfake +Lsentfake +Limg
五、实验
5.1、实验设置
数据集:CUB-Birds、COCO
评价指标:FID、IS、top-1 Acc(评估语义一致性)、mAP(评估多属性分类性能)
实验细节:DF-GAN作为主干网络,Adam优化器、生成器学习率0.0001、鉴别器学习率0.0004,其他细节见原文。
5.2、实验结果
5.3、消融实验
样本联合策略比样本感知策略取得了更好的结果,因为它建模了全局相关性并获得了更合适的属性记忆。此外,样本联结和对齐策略的组合可以获得最佳结果。
下表报告了属性内存的重要性。第3-4行表示使用属性嵌入作为内存初始化的方法。第3行使用固定内存,而第4行使用可学习内存。根据第2行和第4行的比较结果,有必要使用属性嵌入进行内存初始化。与第3行和第4行相比,发现使用更新的内存而不是固定内存有助于模型训练,并能很好地提高性能。
下表报告了哪里插入句子和属性嵌入作用更佳:
六、讨论和结论
多属性描述提供了样本的一般内容,句子提供了属性之间的关联。两者的结合可以合成更逼真的图像和语义匹配的图像。
文章主要贡献在于提出了一种有效的文本表示方法,并补充了属性信息,以帮助控制图像生成。
- 首先,我们构造了一个属性内存来联合控制文本到图像的生成和句子输入。借助属性记忆,丰富了输入文本的表示,从而减少了跨模态间隙。
- 其次,我们探索了两种属性内存更新机制,样本感知和样本联合机制,以动态优化广义属性内存。样本联合机制优于样本感知机制,因为它对数据集中属性之间的全局相关性进行建模。
- 最后,作者在属性到图像、句子到图像和图像到图像中使用对比学习,以促进跨模态对齐。
结合以上所有策略,该方法在CUB和COCO数据集上都取得了显著的性能改进。
附、文本表征 in T2I
从单个句子生成图像是一个从少到多的信息生成过程,这使得生成模型很难进行优化,为了缓解这一问题,许多研究都致力于丰富文本表征。
- 提供附加信息。RiFeGAN等将多条文本描述组合到一起,Gilt使用长文本进行合成,Chatpainter和LANTERN利用视觉问答来丰富细节内容,这些附加说明为文本生成图像带来了丰富的细节,减少了不同模态的差距。
- 从一个句子中挖掘更多表征。AttnGAN使用单词特征使网络关注单词级信息,VICTR分解句子的主谓宾,生成场景嵌入,Dae-gen从一个句子中提取层面信息,CookGAN分别对食材和配方进行建模说明。作者同样采用了这种方法,从一个句子中挖掘对象的属性表征。
最后
💖 个人简介:人工智能领域研究生,目前主攻文本生成图像(text to image)方向
📝 关注我:中杯可乐多加冰
🔥 限时免费订阅:文本生成图像T2I专栏
🎉 支持我:点赞👍+收藏⭐️+留言📝