MoA: Mixture-of-Attention for Subject-Context Disentanglement in Personalized Image Generation
TL; DR:本文提出 MoA,注意力版本的 MoE,在文生图模型 UNet 的注意力层中加入 personalization 和 prior 两个分支,分别用于处理定制化前景物体和非定制化背景。
方法
本文提出 MoA(Mixture of Attention)架构,在文生图模型 UNet 中的注意力层中,存在 prior 和 personalization 两个分支,并有一个 router 来结合二者的输出,得到最终输出。
UNet常规注意力
在常规的文生图 UNet 模型中,存在交叉注意力和自注意力两种注意力。记网络当前的 hidden state 为
Z
\mathbf{Z}
Z ,注意力的计算方式如下:
Z
′
=
Attention
(
Q
,
K
,
V
)
=
Softmax
(
Q
K
T
d
)
V
\mathbf{Z}'=\text{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V})=\text{Softmax}(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d}})\mathbf{V}
Z′=Attention(Q,K,V)=Softmax(dQKT)V
Q = Z W q , K = C W k , V = C W v \mathbf{Q}=\mathbf{Z}\mathbf{W}_q,\ \mathbf{K}=\mathbf{C}\mathbf{W}_k,\ \mathbf{V}=\mathbf{C}\mathbf{W}_v Q=ZWq, K=CWk, V=CWv
其中 W \mathbf{W} W 均为线性变换矩阵。在自注意中 C = Z \mathbf{C}=\mathbf{Z} C=Z ,在交叉注意力中 C = T \mathbf{C}=\mathbf{T} C=T ,其中 T \mathbf{T} T 是文本 prompt embedding。
在常规文生图 UNet 模型中,不同的层有不同的注意力网络权重,而 UNet 又同时以时间步 为条件因此,特定层
l
l
l 及特定时间步
t
t
t 下的注意力可表示为:
Q
t
,
l
=
Z
t
,
l
W
q
l
,
K
t
,
l
=
C
t
,
l
W
q
l
,
V
t
,
l
=
C
t
,
l
W
v
l
\mathbf{Q}^{t,l}=\mathbf{Z}^{t,l}\mathbf{W}_q^l,\ \mathbf{K}^{t,l}=\mathbf{C}^{t,l}\mathbf{W}_q^l, \ \mathbf{V}^{t,l}=\mathbf{C}^{t,l}\mathbf{W}_v^l
Qt,l=Zt,lWql, Kt,l=Ct,lWql, Vt,l=Ct,lWvl
其中各变换矩阵
W
W
W 在不同层
l
l
l 不同。hiddent state
Z
\mathbf{Z}
Z 是关于网络层
l
l
l 和时间步
t
t
t 的函数;条件
C
\mathbf{C}
C 在自注意力中就是
Z
\mathbf{Z}
Z ,而在交叉注意力中
C
=
T
\mathbf{C}=\mathbf{T}
C=T ,文本条件 prompt embedding 本身由 CLIP 文本编码器编码,与
t
t
t 和
l
l
l 无关。但最近也有研究(NeTI)指出,加入时空条件可以提高个性化生成的效果。
MoA
本文提出的 MoA(Mixture of Attention)灵感来自于 MoE(Mixture of Experts)。特殊之处在于两点:1) MoA 的不同 Expert 是注意力层中的线性变换权重;2) MoA 只有 prior 和 personalization 两个特定的 experts。具体的,MoA 可表示为:
Z
t
,
l
=
∑
n
=
1
2
R
n
t
,
l
Attention
(
Q
n
t
,
l
,
K
n
t
,
l
,
V
n
t
,
l
)
\mathbf{Z}^{t,l}=\sum_{n=1}^2\mathbf{R}_n^{t,l}\text{Attention}(\mathbf{Q}_n^{t,l},\mathbf{K}_n^{t,l},\mathbf{V}_n^{t,l})
Zt,l=n=1∑2Rnt,lAttention(Qnt,l,Knt,l,Vnt,l)
R n t , l = Router l ( Z t , l ) \mathbf{R}_n^{t,l}=\text{Router}^{l}(\mathbf{Z}^{t,l}) Rnt,l=Routerl(Zt,l)
其中 R \mathbf{R} R 表示 router 路由,将不同的分支的输出进行软结合(加权和),注意这里是每个分支都有一个 router。我们将两个分支都以原注意力层网络权重进行初始化,在训练时,prior 分支权重冻结不更新,只更新 personalization 分支的权重。
在自注意力 MoA 中,两个分支接受的输入是相同的,都是前一层的 hidden state;而在交叉注意力 MoA 中,两个分支接受的输入是不同的。在交叉注意力中,prior 分支接受标准的文本条件 prompt embedding 作为输入,而 personalization 分支,为了同时考虑定制化参考图像,接受的是一种多模态的 prompt。
给定一张参考图像,首先用一个图像编码器(如 CLIP Image Encoder)提取其图像特征 embedding,然后该图像特征 embedding 与对应文本 token 的 embedding 进行拼接(如图示中,就应该与 <man>
对应的 token 进行拼接),从而得到了多模态的条件 embedding。然后,再通过一个可学习的位置编码,为这个多模态条件 embedding 加入两个条件:时间步
t
t
t 和当前网络层层
l
l
l 。最后送到一个 MLP 中。这个操作也是参考之前一些个性化生成方法中指出,加入时空条件可以提高定制化目标的一致性。
训练
router 训练的目标是让背景像素(非定制化目标的像素)更多地使用 prior 分支,前景像素不需要特别地优化。router 的损失表示为:
L
router
=
∣
∣
(
1
−
M
)
⊙
(
1
−
R
)
∣
∣
R
=
1
∣
L
∣
∑
l
∈
L
R
0
l
\mathcal{L}_\text{router}=||(1-\mathbf{M})\odot(1-\mathbf{R})|| \\ \mathbf{R}=\frac{1}{|\mathbb{L}|}\sum_{l\in\mathbb{L}}\mathbf{R}_0^l
Lrouter=∣∣(1−M)⊙(1−R)∣∣R=∣L∣1l∈L∑R0l
其中 R 0 l \mathbf{R}_0^l R0l 表示第 l l l 层的 prior 的 router, M \mathbf{M} M 表示目标物体的掩码图, L \mathbb{L} L 表示加入 MoA 的层。实际中,只有 UNet 的前两层和后三层没有加 MoA,因为这几层是比较浅层的特征,与目标物体的关系不大。在训练完成后,理想的状态是 personalization 分支聚焦于前景物体,而 prior 分支主要负责背景,但也对前景、前背景分界处有一定的影响。
不同层 / 不同时间步的 router 可能会关注到前景物体的不同区域,比如在下图的可视化实验中,有的层关注人物的脸部,有的层则关注人物的躯体。
MoA 整体的训练损失由三部分组成:
L
=
L
masked
+
λ
r
L
router
+
λ
o
L
object
\mathcal{L}=\mathcal{L}_\text{masked}+\lambda_r\mathcal{L}_\text{router}+\lambda_o\mathcal{L}_\text{object}
L=Lmasked+λrLrouter+λoLobject
其中
L
router
\mathcal{L}_\text{router}
Lrouter 已经介绍过,
L
masked
\mathcal{L}_\text{masked}
Lmasked 是指仅对掩码前景计算的重构损失:
L
masked
(
Z
,
Z
′
)
=
∣
∣
M
⊙
(
Z
−
Z
′
)
∣
∣
2
2
\mathcal{L}_\text{masked}(\mathbf{Z},\mathbf{Z'})=||\mathbf{M}\odot(\mathbf{Z}-\mathbf{Z'})||_2^2
Lmasked(Z,Z′)=∣∣M⊙(Z−Z′)∣∣22
而
L
object
\mathcal{L}_\text{object}
Lobject 采用了 FastComposer 提出的一种均衡 L1 损失:
L
object
=
1
L
S
∑
l
∈
L
∑
s
∈
S
mean
(
(
1
−
M
s
)
⊙
(
1
−
A
s
l
)
)
−
mean
(
M
s
⊙
A
s
l
)
\mathcal{L}_\text{object}=\frac{1}{\mathbb{L}\mathbb{S}}\sum_{l\in\mathbb{L}}\sum_{s\in\mathbb{S}}\text{mean}((1-\mathbf{M}_s)\odot(1-\mathbf{A}_s^l))-\text{mean}(\mathbf{M}_s\odot\mathbf{A}_s^l)
Lobject=LS1l∈L∑s∈S∑mean((1−Ms)⊙(1−Asl))−mean(Ms⊙Asl)
其中
S
\mathbb{S}
S 表示图片中物体的集合,
M
s
\mathbf{M}_s
Ms 表示物体
s
s
s 的分割掩码,
A
s
l
\mathbf{A}_s^l
Asl 表示交叉注意力图中物体
s
s
s 对应的 token 的位置。
总结
MoA 的方案虽然需要训练,但是训的是定制化生成的能力,而不是像 LoRA 一样只能训某一个 subject。同时,在训练阶段加入 subject mask 让模型学习,从而在生图阶段不需要再显式地加入 mask。整体相比于之前的多概念定制化方法(Mix of show, ConsiStory, OMG, CustomDiffusion)等看起来都要更优雅一些。