MoA Mixture-of-Attention for Subject-Context Disentanglement in Personalized Image Generation

MoA: Mixture-of-Attention for Subject-Context Disentanglement in Personalized Image Generation

TL; DR:本文提出 MoA,注意力版本的 MoE,在文生图模型 UNet 的注意力层中加入 personalization 和 prior 两个分支,分别用于处理定制化前景物体和非定制化背景。


方法

本文提出 MoA(Mixture of Attention)架构,在文生图模型 UNet 中的注意力层中,存在 prior 和 personalization 两个分支,并有一个 router 来结合二者的输出,得到最终输出。

UNet常规注意力

在常规的文生图 UNet 模型中,存在交叉注意力和自注意力两种注意力。记网络当前的 hidden state 为 Z \mathbf{Z} Z ,注意力的计算方式如下:
Z ′ = Attention ( Q , K , V ) = Softmax ( Q K T d ) V \mathbf{Z}'=\text{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V})=\text{Softmax}(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d}})\mathbf{V} Z=Attention(Q,K,V)=Softmax(d QKT)V

Q = Z W q ,   K = C W k ,   V = C W v \mathbf{Q}=\mathbf{Z}\mathbf{W}_q,\ \mathbf{K}=\mathbf{C}\mathbf{W}_k,\ \mathbf{V}=\mathbf{C}\mathbf{W}_v Q=ZWq, K=CWk, V=CWv

其中 W \mathbf{W} W 均为线性变换矩阵。在自注意中 C = Z \mathbf{C}=\mathbf{Z} C=Z ,在交叉注意力中 C = T \mathbf{C}=\mathbf{T} C=T ,其中 T \mathbf{T} T 是文本 prompt embedding。

在常规文生图 UNet 模型中,不同的层有不同的注意力网络权重,而 UNet 又同时以时间步 为条件因此,特定层 l l l 及特定时间步 t t t 下的注意力可表示为:
Q t , l = Z t , l W q l ,   K t , l = C t , l W q l ,   V t , l = C t , l W v l \mathbf{Q}^{t,l}=\mathbf{Z}^{t,l}\mathbf{W}_q^l,\ \mathbf{K}^{t,l}=\mathbf{C}^{t,l}\mathbf{W}_q^l, \ \mathbf{V}^{t,l}=\mathbf{C}^{t,l}\mathbf{W}_v^l Qt,l=Zt,lWql, Kt,l=Ct,lWql, Vt,l=Ct,lWvl
其中各变换矩阵 W W W 在不同层 l l l 不同。hiddent state Z \mathbf{Z} Z 是关于网络层 l l l 和时间步 t t t 的函数;条件 C \mathbf{C} C 在自注意力中就是 Z \mathbf{Z} Z ,而在交叉注意力中 C = T \mathbf{C}=\mathbf{T} C=T ,文本条件 prompt embedding 本身由 CLIP 文本编码器编码,与 t t t l l l 无关。但最近也有研究(NeTI)指出,加入时空条件可以提高个性化生成的效果。

MoA

本文提出的 MoA(Mixture of Attention)灵感来自于 MoE(Mixture of Experts)。特殊之处在于两点:1) MoA 的不同 Expert 是注意力层中的线性变换权重;2) MoA 只有 prior 和 personalization 两个特定的 experts。具体的,MoA 可表示为:
Z t , l = ∑ n = 1 2 R n t , l Attention ( Q n t , l , K n t , l , V n t , l ) \mathbf{Z}^{t,l}=\sum_{n=1}^2\mathbf{R}_n^{t,l}\text{Attention}(\mathbf{Q}_n^{t,l},\mathbf{K}_n^{t,l},\mathbf{V}_n^{t,l}) Zt,l=n=12Rnt,lAttention(Qnt,l,Knt,l,Vnt,l)

R n t , l = Router l ( Z t , l ) \mathbf{R}_n^{t,l}=\text{Router}^{l}(\mathbf{Z}^{t,l}) Rnt,l=Routerl(Zt,l)

其中 R \mathbf{R} R 表示 router 路由,将不同的分支的输出进行软结合(加权和),注意这里是每个分支都有一个 router。我们将两个分支都以原注意力层网络权重进行初始化,在训练时,prior 分支权重冻结不更新,只更新 personalization 分支的权重。

在这里插入图片描述

在自注意力 MoA 中,两个分支接受的输入是相同的,都是前一层的 hidden state;而在交叉注意力 MoA 中,两个分支接受的输入是不同的。在交叉注意力中,prior 分支接受标准的文本条件 prompt embedding 作为输入,而 personalization 分支,为了同时考虑定制化参考图像,接受的是一种多模态的 prompt。

在这里插入图片描述

给定一张参考图像,首先用一个图像编码器(如 CLIP Image Encoder)提取其图像特征 embedding,然后该图像特征 embedding 与对应文本 token 的 embedding 进行拼接(如图示中,就应该与 <man> 对应的 token 进行拼接),从而得到了多模态的条件 embedding。然后,再通过一个可学习的位置编码,为这个多模态条件 embedding 加入两个条件:时间步 t t t 和当前网络层层 l l l 。最后送到一个 MLP 中。这个操作也是参考之前一些个性化生成方法中指出,加入时空条件可以提高定制化目标的一致性。

训练

router 训练的目标是让背景像素(非定制化目标的像素)更多地使用 prior 分支,前景像素不需要特别地优化。router 的损失表示为:
L router = ∣ ∣ ( 1 − M ) ⊙ ( 1 − R ) ∣ ∣ R = 1 ∣ L ∣ ∑ l ∈ L R 0 l \mathcal{L}_\text{router}=||(1-\mathbf{M})\odot(1-\mathbf{R})|| \\ \mathbf{R}=\frac{1}{|\mathbb{L}|}\sum_{l\in\mathbb{L}}\mathbf{R}_0^l Lrouter=∣∣(1M)(1R)∣∣R=L1lLR0l

其中 R 0 l \mathbf{R}_0^l R0l 表示第 l l l 层的 prior 的 router, M \mathbf{M} M 表示目标物体的掩码图, L \mathbb{L} L 表示加入 MoA 的层。实际中,只有 UNet 的前两层和后三层没有加 MoA,因为这几层是比较浅层的特征,与目标物体的关系不大。在训练完成后,理想的状态是 personalization 分支聚焦于前景物体,而 prior 分支主要负责背景,但也对前景、前背景分界处有一定的影响。

不同层 / 不同时间步的 router 可能会关注到前景物体的不同区域,比如在下图的可视化实验中,有的层关注人物的脸部,有的层则关注人物的躯体。

在这里插入图片描述

MoA 整体的训练损失由三部分组成:
L = L masked + λ r L router + λ o L object \mathcal{L}=\mathcal{L}_\text{masked}+\lambda_r\mathcal{L}_\text{router}+\lambda_o\mathcal{L}_\text{object} L=Lmasked+λrLrouter+λoLobject
其中 L router \mathcal{L}_\text{router} Lrouter 已经介绍过, L masked \mathcal{L}_\text{masked} Lmasked 是指仅对掩码前景计算的重构损失:
L masked ( Z , Z ′ ) = ∣ ∣ M ⊙ ( Z − Z ′ ) ∣ ∣ 2 2 \mathcal{L}_\text{masked}(\mathbf{Z},\mathbf{Z'})=||\mathbf{M}\odot(\mathbf{Z}-\mathbf{Z'})||_2^2 Lmasked(Z,Z)=∣∣M(ZZ)22
L object \mathcal{L}_\text{object} Lobject 采用了 FastComposer 提出的一种均衡 L1 损失:
L object = 1 L S ∑ l ∈ L ∑ s ∈ S mean ( ( 1 − M s ) ⊙ ( 1 − A s l ) ) − mean ( M s ⊙ A s l ) \mathcal{L}_\text{object}=\frac{1}{\mathbb{L}\mathbb{S}}\sum_{l\in\mathbb{L}}\sum_{s\in\mathbb{S}}\text{mean}((1-\mathbf{M}_s)\odot(1-\mathbf{A}_s^l))-\text{mean}(\mathbf{M}_s\odot\mathbf{A}_s^l) Lobject=LS1lLsSmean((1Ms)(1Asl))mean(MsAsl)
其中 S \mathbb{S} S 表示图片中物体的集合, M s \mathbf{M}_s Ms 表示物体 s s s 的分割掩码, A s l \mathbf{A}_s^l Asl 表示交叉注意力图中物体 s s s 对应的 token 的位置。

总结

MoA 的方案虽然需要训练,但是训的是定制化生成的能力,而不是像 LoRA 一样只能训某一个 subject。同时,在训练阶段加入 subject mask 让模型学习,从而在生图阶段不需要再显式地加入 mask。整体相比于之前的多概念定制化方法(Mix of show, ConsiStory, OMG, CustomDiffusion)等看起来都要更优雅一些。

  • 22
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值