Consistory 无需训练的人物一致性生成方法
来自于论文《Training-Free Consistent Text-to-Image Generation》
目前的文本生成图像模型在不同提示下始终如一地呈现相同主题存在挑战。现有方法通过微调模型或向模型添加图像条件,但这些方法耗时较久甚至需要大规模的预训练。此外,它们在生成的图像与文本提示对齐方面存在困难,添加图像条件之后往往会使得模型生成结果与提示词相差较大,并且难以表现多个主题。Consistory有效的解决了这些问题,其无需训练,推理速度快,且生图效果非常好,在多主题一致性中也有不俗的表现
同时,其有效解决了先前工作中,加入一致性要求后,模型很难遵守prompt生成的问题
下面我们详细介绍下Consistory到底是怎么做到的
subjects-specific masks
在人物一致性中,我们希望做到人物尽可能相似,同时其背景尽可能不相似以达到生成图像的多样性。Consistory方法的核心是操作注意力层,所以其中关键的问题就是在模型的自注意力层self-attention
中,把人物特征和背景特征区分开来。
作者用一个简单但有效的方法解决了这一问题。在模型推理的过程中,记录所有出现在cross-attention
层,分辨率为32x32,与token相关联的attention-map
比如,在len(prompt)=3时,由于采用classifier-free guidance(x2)和多头注意力机制(x20),进入unet的第二个CrossAttnDownBlock2D
块的cross-attention
层时,attention_probs
的维度为[120, 1024, 77],由于有一半的prompt是空字符串,所以实际有用的attention_probs
的维度为[60, 1024, 77]。假如prompt的格式均为“a cat …”,那么subject_token就等于2,所以我们需要把attention_probs[60 : , : , 2]保存起来,并通过相加的方式,按照prompt分组,将其压缩为[3, 1024 ]大小形成attention_map
当我们需要subjects-specific masks时,将目前得到的所有attention_map
求平均,利用otsu
方法即可得到。subjects-specific masks是由0和1组成的向量,1表示这一点的特征属于人物特征,0表示背景
SDSA
sdsa全称是subject-driven self-attention,其操作对象是self-attention
层,核心目的是让代表图片特征的Q可以看到其他图片中代表subject的K和V,公式描述如下:
K
+
=
[
K
1
⊕
K
2
⊕
…
⊕
K
N
]
∈
R
N
⋅
P
×
d
k
V
+
=
[
V
1
⊕
V
2
⊕
…
⊕
V
N
]
∈
R
N
⋅
P
×
d
v
M
i
+
=
[
M
1
…
M
i
−
1
⊕
1
⊕
M
i
+
1
…
M
N
]
A
i
+
=
softmax
(
Q
i
K
+
⊤
d
k
+
log
M
i
+
)
∈
R
P
×
N
⋅
P
h
i
=
A
i
+
⋅
V
+
∈
R
P
×
d
v
K^+ = [K_1 \oplus K_2 \oplus \ldots \oplus K_N] \in \mathbb{R}^{N \cdot P \times d_k} \\ V^+ = [V_1 \oplus V_2 \oplus \ldots \oplus V_N] \in \mathbb{R}^{N \cdot P \times d_v} \\M^+_i = [M_1 \ldots M_{i-1} \oplus 1 \oplus M_{i+1} \ldots M_N] \\A^+_i = \text{softmax} \left( \frac{Q_i K^{+\top}}{\sqrt{d_k}} + \log M^+_i \right) \in \mathbb{R}^{P \times N \cdot P} \\h_i = A^+_i \cdot V^+ \in \mathbb{R}^{P \times d_v}
K+=[K1⊕K2⊕…⊕KN]∈RN⋅P×dkV+=[V1⊕V2⊕…⊕VN]∈RN⋅P×dvMi+=[M1…Mi−1⊕1⊕Mi+1…MN]Ai+=softmax(dkQiK+⊤+logMi+)∈RP×N⋅Phi=Ai+⋅V+∈RP×dv
符号 ⊕ \oplus ⊕表示向量或矩阵的连接操作, M i M_i Mi就是上一节得到的关于某个图片的subjects-specific mask,这里的 M i + M^+_i Mi+是 V + V^+ V+的掩码,所以第四步计算 A i + A^+_i Ai+时首先要对 log M i + \log M^+_i logMi+ 进行广播,使其从1维变到与Q相同维度,由于在实际操作中使用的是多头注意力,所以这里应该是不同图片对应的注意力头相互观察,不对应的相互不影响
Enriching layout diversity
使用了SDSA和掩码机制之后,恢复了prompt alignment
但是这导致了不同图像中的subject生成在了相似的位置且具有相似的姿势,针对这一问题,作者提出了两个方法用于改进
Using Vanilla Query Features
将生成图片初始latent的随机种子固定下来,先用原始模型生成一遍,记录下每一步自注意力层的Q值,之后用相同的随机种子重新生成,这次将自注意力层替换为SDSA,不同的是,将将每一步SDSA中的Q与之前记录的Q进行融合
Q
t
∗
=
(
1
−
ν
t
)
Q
t
SDSA
+
ν
t
Q
t
vanilla
Q^*_t = (1 - \nu_t) Q^{\text{SDSA}}_t + \nu_t Q^{\text{vanilla}}_t
Qt∗=(1−νt)QtSDSA+νtQtvanilla
Self-Attention Dropout
在每个去噪步骤中,随机将 M i M_i Mi中的一部分patch置为0。这削弱了不同图像之间的注意力共享,从而促进更丰富的布局变化。另外,通过调整dropout概率,可以调节一致性的强度,在人物一致性和布局变化之间找到平衡
Feature injection
上述的方法有效提升了人物一致性,但是在捕捉和再现细微的视觉特征方面时遇到挑战。这些细微的视觉特征对于识别和区分图像中的人物是非常重要的。如果模型不能准确地再现这些细节特征,就可能导致生成的图像中主体的身份特征不够清晰或准确,从而影响对主体的正确识别,为了更进一步的解决这个问题,作者又提出了Feature injection
特征注入中使用的特征是DIFT,来源于论文《Emergent Correspondence from Image Diffusion》,在文中作者写到we extract the feature maps of its intermediate layers at a specific time step t during the backward process,也就是说这一特征来源于推理过程中Unet的中间的某一层中。这一特征很有意思,作者发现,给定两张图片并计算出其DIFT特征,那么在特征中相似的点对应于原始图片中语义上相似的两个区域。
在这篇文章中,作者首先用原始模型推理到固定的某一个timestep,得到每一张图片的DIFT,之后计算出DIFT map
,用于描述一张图片的DIFT上的一个点与哪一张图片DIFT的哪一个点最为接近。之后用相同的随机种子重新生成,在生成过程中,利用DIFT map
中的信息,在特定的timesteps中进行DIFT融合,在融合的过程中接受subjects-specific masks的指导,只对subject对应的DIFT中的点进行融合
x
out
(
t
)
=
(
1
−
α
)
⋅
x
out
(
t
)
+
α
⋅
src
(
x
out
(
t
)
)
\mathbf{x}_{\text{out}}^{(t)} = (1 - \alpha) \cdot \mathbf{x}_{\text{out}}^{(t)} + \alpha \cdot \text{src}(\mathbf{x}_{\text{out}}^{(t)})
xout(t)=(1−α)⋅xout(t)+α⋅src(xout(t))
Anchor images and reusable subjects
为了加快推理时间,作者设计了Anchor,在之前的生成过程中,每一张图片要观察其他所有图片,这就导致了很大的attention计算量。作者发现,可以通过指定某几张图片为anchor,anchor之间相互参考,所有的其他图片的sdsa计算和feature注入只参考anchor,相互之间不再参考。作者发现只需两张anchor就完全足够
另外只需固定prompt和seed,我们就可以在其他的batch中再现相同的anchor,从而做到无穷无尽的一致性图片生成
Multi-subject consistent generation
个性化方法在单幅图像中保持多个主体的一致性方面存在困难 。然而,使用ConsiStory,通过简单地取主体掩码的并集,可以以一种简单明了的方式实现多主体一致生成。当主体在语义上不同的时候,它们之间的信息泄漏并不是一个问题。这是因为注意力softmax的指数形式起到了门控作用,抑制了不相关主体之间的信息泄漏。同样,在特征注入过程中对对应图进行阈值处理也产生了类似的门控效果,防止了信息泄漏。