利用DreamBooth实现对于文生图扩散模型的微调
文章目录
Google Research在2022年8月提出的一种 全新的对于文生图扩散模型(Text-to-Image Diffusion Model)的微调方式,在之后也成功入选 CVPR 2023,论文链接: DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation
DreamBooth的效果是相当卓越的,我们能够让模型“记住”某个物体长什么样,然后像照相馆一样,给你拍出各种各样的照片来。
所以接下来我们先来介绍一下微调技术出现的一些背景。
1.背景
- 目前出现的很多文生图大模型已经具备了相当卓越的生成能力,例如Stable-Diffusion、Imagen等等。
- 如果我们在Prompt中输入一个 [Van Gogh Style],或许它能很快给你画出梵高风格的一幅画,但我作为一个不知名的人,我也不能奢求Stable Diffusion在训练的时候把我的画也加进去是吧?所以如果我们能自己个性化调整一下这个模型就好了。
- 也正是因为上面说的,结合扩散模型本身生成内容的多样性我们可以知道:同一个提示词生成几十几百次得到的图片可能都不一样,现在的模型不能保证生成相同物体的一致性。
- 而如果依照我们的自然想法,那肯定是把整个模型重新训练调整一下,但是这样一来首先我们肯定需要比较大的数据样本,二来对于整个大模型进行重新训练的成本非常高,训练整个Stable-Diffusion-1.4大概要15万GPU小时,如果是这样对于我们来说就变得很不可及了。
2.DreamBooth
于是,Google Research团队提出了一种名为DreamBooth的全新微调方法。
原本的Text-to-Image模型(左侧红色的部分)能够根据"A dog"这条Prompt生成不同的狗的照片,经过DreamBooth微调后的模型能够在接收"A [V] dog"和"A dog"的两种情况下,分别绘制出曾经记住的那条狗以及扩散模型本身能够画出的这些狗。
这也就是说,DreamBooth方法对于有或没有标识符的Prompt生成图像的情况下,采取的是相同的权重,这也就告诉我们:DreamBooth方法是对整个扩散模型进行的微调。
Google Research在论文中对DreamBooth中提到:“It’s like a photo booth, but once the subject is captured, it can be synthesized wherever your dreams take you”, DreamBooth就像一个照相馆,能够让我们像拍照片一样生成各种各样的图片。
3.基本效果
(1).一般效果
这张图片中展示了经过微调后的生成模型与图片指导的DALL-E2,文本指导的Imagen在生成同一个物体上的效果。
我们输入的是三种同一个闹钟的照片,闹钟的一个重要特征是白底,且右边有一个黄色的数字3,然后我们送进各种模型里,再让它们生成图片,结果发现:OpenAI的DALL-E2在物体的保真性和新内容生成两方面的效果都很差,不仅闹钟的形状不对,连生成的三种图片的背景都是在只有在被子上的,这样就不符合我们的要求了。
然后是Google的Imagen,它比DALL-E2要好一点,至少背后有了新的场景,不过这个闹钟还是没有像以前一样。
最后就是DreamBooth微调后的模型,这个效果真的可以说是相当好了,物体保真性和新内容生成都成功兼顾到了,当然,它也是有一定瑕疵的,比如我们可以看到DreamBooth放大的那一张图片上,分钟刻度是有一些模糊的,不过这可能也正好能够说明:这图确实是AI画的,不是照片。
(2).和Textual Inversion对比
DreamBooth主要的对手就是之前提出的Textual Inversion,让我们来看看效果:
相比之下效果还是比较明显的对吧?中间的两组图片是用DreamBooth微调过后的Imagen和Stable Diffusion模型画出来的,而最后一个是用Textual Inversion微调后画出来的图片,Textual Inversion看起来是范围太大,把学习到的花瓶的纹理迁移到整个画面上去了,这其实提示我们:Textual Inversion或许更适合用来做风格迁移方面的微调。
3.摘要
在这篇文章中,我们提出了一种全新的文本到图像扩散模型 “个性化”方法。给定几个作为主体的图像作为输入,我们便可实现对于预训练的文生图模型的微调,使得其将唯一标识符(Unique Identifier)与特定主体绑定。
例如我们用sks作为标识符绑定某只特定的猫,通过DreamBooth微调后的模型,则输入prompt: a sks cat running on the lawn后,就会生成一张我们微调训练时使用的猫在草坪上奔跑的图片。
一个小问题
Google Research在发表这篇论文后,并没有公开团队自己用于完成论文的DreamBooth模型代码,现在能够找到的代码基本都是在其论文基础上的第三方复现,如后续会提到的Diffusers库提供的DreamBooth训练脚本。
不过其实复现出来的效果也并不差,说明这个模型的确是比较成功的,接下来就让我们来看看论文的核心部分吧。
4.Personalization of Text-to-Image Models
(1).目标与思路
- 目标:为了能够让模型记住输入的主体,我们需要把物体植入到模型的输出域中。
- 思路:采用小样本进行模型微调。
(2).为什么要选小样本?
先说说为什么是小样本,其实很好想,比如我们如果希望让这个扩散模型学会某个书法家的字体,然后给我生成一些字的图片,那我总不能说:让人家给我把常用的三四千个汉字全写一遍吧?
这样就完全没有意义了,我们希望的就是模型能够在不那么大的代价下学习到这个书法家的字体,这个代价一般不是训练过程的代价,而是我们为了让模型学会付出的代价,比如你教小朋友十以内的加法,你不可能为了让ta学会,就打印几千几万道这样的题让ta算,这样效率太低了。
因此,对于这种微调的情况,我们要尽可能用小的样本来实现微调,而DreamBooth很理想,只需要3~5张图片就可以实现这个事情,不过这么小的样本,如果搭配上四五千轮的训练,感觉会非常容易过拟合,所以我们就有了以下的问题。
(3).那么问题来了
- 问题:在既往对于GAN模型微调的研究中,小样本量的微调会产生过拟合和模式坍塌的问题。
(4).模式坍塌是啥?
模式坍塌是GAN的训练问题,指生成器(G)只能生成真实数据分布中的一部分或一种模式,而忽略了其他的多样性。例如,用GAN生成手写数字图像,有时GAN可能只生成其中一种或几种数字,这就是模式坍塌。
一般来说,模式坍塌是由于生成器和判别器(D)之间的对抗关系不平衡造成的。如果D太强,G就会找到一种能骗过D的模式,并且不断重复这种模式;如果D太弱,G就会找到一种最容易生成的模式,且不再探索其他的模式。
还有点迷糊,我再换个例子解释一下:你的高中班主任是一个非常非常严格的人,ta给你们班定了一系列非常严格的班规,你觉得非常不自在,作为一个不是那么听话的孩子,你决定反其道而行之,在充分研究了班规后,你发现了一个漏洞,然后完成了你想做的事情。
另一个方面,你的高中班主任完全不管你,这时候你发现我直接当着ta的面玩手机都行,于是你也就不再去像别的办法做你想做的事情了。
这就对应了上面说的两种对抗关系不平衡的情况,当然我没有说高中班主任怎么样的意思,只是举个例子。
对于模式坍塌这个问题,我的理解就是:AI模型本身其实也是一大堆数值堆起来的数据模型,它暂时还没有如此智能,因此训练的时候,我们设定的Loss函数只是指导它应该往哪个方向收敛,并没有规划路径,因此它完全有可能走一些“歪门邪道”来解决问题,这和我们后面还会提到的Language Drift是相似的问题。
(5).那咋办?
不咋办,DreamBooth是对扩散模型做的微调,而大的扩散模型“看样子”很擅长在不丢失原有的参数以及对小数据样本产生过拟合的情况下,将新的信息整合进入其输出域中。
其实也好理解,我之前提过,大的扩散模型是有相当强的生成多样性的,那么从这个方面上来看,整合新信息对于扩散模型来说应该并不困难。
5.Designing Prompts for Few-Shot Personalization
(1).构造提示词的方法
还真别小看这个部分,提示词的构造其实对于微调模型相当重要,毕竟你是在尝试让模型给你输出原本的物体,那你肯定要用比较精确的提示词才能有好的效果。
DreamBooth的提示词是这样构成的:a [identifier] [class noun],identifier是一个标识符,一般要求使用比较稀有的、不常见的词作为标识符;而class noun是对于需要标识的物体的“类别描述词”,例如dog, cat等等。
(2).为什么要选择稀有词而不是随机词?
DreamBooth特别要求,在构造提示词的时候要尽可能使用词典中存在的不常见/稀有词,而不是常见词或者随机生成的词。
构造提示词最简单的想法其实就是找个存在的词,比如我要记住一只猫,提示词就让它是a cat,你非说这样不行吧,也不至于,但是这样就会使得扩散模型在这个小样本上过拟合,然后失去通过这个词生成图片的泛化能力,毕竟你要让扩散模型准确输出这只猫,那就肯定要丢掉其他猫的记忆了,这也就是一种过拟合的问题。
那既然上面这条路走不通,那我就让你完全认不出来吧,比如文中随机构造了一个"xxy5syt00"作为标识符,这总行了吧?不行哦,这样构造的词可能效果跟上面那种一样差。 为什么?这扩散模型不可能认识啊,其实这是犯了一个主观上的错误,它确实不认识这个词,但是分词器可能也不认识啊,它在接收我们的提示词输入的时候,有可能会根据一定的分词方式,将这种随机标识符分成几个词,而可能其中某个词,扩散模型是认识的,这样一来,就回到的第一种情况中说的过拟合了。
比如我是一个没什么新意的人,我给这个标识符起一个:myspecialcat,有可能经过分词后就变成了my special cat对应的向量,这就麻烦大了,我们到最后可能得到跟a cat一样的效果了。
所以,综上所述,作者在词表中选择罕见词来作为特殊标记符,这样避免了预训练模型对特殊标记符有丰富的先验知识。
其实这里我们也理解它的做法是什么了,就是把这个新的主体绑定到这个稀有词上,从而完成注入的过程。
6.Class-specific Prior Preservation Loss
(1).新的问题!
-
问题1:Language Drift,对模型进行微调可能产生在Countering Language Drift with Seeded Iterated Learning1以及Countering Language Drift via Visual Grounding2中提到的语言漂移问题。
-
问题2:过拟合,长时间训练可能导致大模型丢失生成物体的多样性。
(2).Language Drift是什么?
#1.Countering Language Drift with Seeded Iterated Learning
Yuchen Lu等在Countering Language Drift with Seeded Iterated Learning中认为:
They slowly lose syntactic and semantic properties of language as they only focus on solving the task.
他们逐渐失去语言的句法和语义属性,而仅仅关注解决任务。
也就是说,模型在完成任务的过程中,仅仅为了完成任务,丢失了我们希望它使用的自然语言属性,这样就出现了语言漂移的问题。
#2.Countering Language Drift via Visual Grounding
Jason Lee等则在Countering Language Drift via Visual Grounding中提出
When a nonlinguistic reward is used in a goal-based task, e.g. some scalar success metric, the communication protocol may easily and radically diverge from natural language.
当在基于目标的任务中使用非语言奖励时,例如一些用于度量成功的指标,通信协议可能很容易从根本上偏离自然语言。
这里的通信协议指的是Multi-Agent多智体的通信协议,大概也就是指对于这样一个多智体在语言方面的输入和输出时的协议。
这一段话其实比上面那个要更好理解一点,也就是说,我们使用的Loss函数总是一个数学表达式,并没有完全使用我们对于某个目标达成的评估方式。
举个例子:你在教小朋友说话的时候,ta说:“我苹果想吃”,你知道这不对,你纠正ta说:“我想吃苹果”,这就是在依靠语言的方式来评判语言本身是否正确。
但假设我今天有一个不那么智能的程序,它接收到了我、苹果、想吃三个词之后,认为理解了你说的话,没有去纠正,就认为你是对的了,这就麻烦了,它的评判标准没有完全考虑到语法本身,如果长期保持这样一种方式进行训练,就可能导致小朋友说的话变成一种我们不太能听懂的语言了。
Language Drift大概也就是这样的问题了,我们在训练的过程当中,可能会因为评估方式的差异,导致这个模型最终在语言方面产生一种完全不属于自然语言范畴的语言,这样虽然它解决了自己的问题,但不符合我们的要求。
这里提到Language Drift在我看来可能有两重意思,第一重是论文团队可能希望借此来描述扩散模型的遗忘问题,也就是失去多样性的过程。
第二重是换个思路,Language Drift本身在语言学中的含义大概是:语言中的词、句或者用法等在时间推演过程中,由于受到社会文化、经济情况等各种原因的影响下,发生了意义变化的现象。
我们可以用微信的表情来举个例子:微信里有一个微笑的表情,最早我们也是照着微笑这个意思去发的,但是后来年轻人觉得这个表情很奇怪,虽然是在笑,但是眼珠向下翻,漏出很多白眼,而且眼睛也没有什么太大的变化,给人一种强行微笑的感觉,所以现在我们用这个表情来表示我对你说得对不感兴趣之类的意思,这样一来,我想你大概就明白语言漂移是什么了吧?
那么提到Language Drift,可能是说,我们在通过DreamBooth微调模型本身,把这个新的物体绑定到某个稀有词这个过程,本身也是属于语言漂移的过程,当然,这是我自己的一点理解,原文中的意思应该更加接近我说的上一种情况。
(3).所以,我们选择PPL
DreamBooth为解决以上问题,对Diffusion模型原有的Loss函数中加入了一个可调节的Prior Preservation Loss(PPL),Loss函数变为以下形式:
L
=
E
x
,
c
,
ϵ
,
ϵ
′
[
w
t
∣
∣
x
^
θ
(
α
t
x
+
σ
t
ϵ
,
c
)
−
x
∣
∣
2
2
+
λ
w
t
′
∣
∣
x
^
θ
(
α
t
′
x
p
r
+
σ
t
′
ϵ
′
,
c
p
r
)
−
x
p
r
∣
∣
2
2
]
\large \mathcal{L} = \mathbb{E}_{\bold{x},\bold{c},\bold{\epsilon},\bold{\epsilon'}}[w_t||\hat{\bold{x}}_\theta(\alpha_t\bold{x}+\sigma_t\bold{\epsilon},\bold{c})-\bold{x}||^2_2+\lambda w_{t'}||\hat{\bold{x}}_\theta(\alpha_{t'}\bold{x_{pr}}+\sigma_{t'}\bold{\epsilon'},\bold{c_{pr}})-\bold{x_{pr}}||^2_2]
L=Ex,c,ϵ,ϵ′[wt∣∣x^θ(αtx+σtϵ,c)−x∣∣22+λwt′∣∣x^θ(αt′xpr+σt′ϵ′,cpr)−xpr∣∣22] 其中
x
^
θ
\hat{\bold{x}}_\theta
x^θ为模型,
x
p
r
=
x
^
θ
(
z
t
1
,
c
p
r
)
\bold{x}_{pr}=\hat{\bold{x}}_\theta({z_{t_1},\bold{c}_{pr}})
xpr=x^θ(zt1,cpr)为预训练模型根据包含标识符的提示词生成的向量,
z
t
1
∼
N
(
0
,
I
)
z_{t_1}\sim \bold{N}(0,I)
zt1∼N(0,I),
c
p
r
\bold{c}_{pr}
cpr是通过Text-Encoder得到的包含标识符的文本条件向量,
λ
\lambda
λ是用于控制微调部分的权重。
其实这个Loss函数相较于一般的扩散模型只是加了后面这一部分,但是能够很好地保障我们在训练的时候可以保留住扩散模型的先验知识。
不过从这个新的Loss函数我们其实也很容易看出:DreamBooth真的是对扩散模型整体进行调整,其意图在于将需要模型“记住”的主体嵌入到扩散模型中,从而以最大化地保证原有主体的保真度,而由此也会带来更大的资源(时间、显存等)消耗。
(4).这里说的过拟合又是什么?
等一下,在这里我们说的过拟合到底是什么?是模型对小样本学习程度过高导致模型只能在验证集上发挥比较好的效果吗? 不对吧?我们希望微调之后的模型能够很好地记住输入物体的真实样貌,好像,就是希望它过拟合啊!
这里的过拟合实际上是说:我们希望模型记住物体,但不希望只能记住这个物体的几个状态,例如输入的狗狗只有趴着一个姿势,如果这样过拟合,可能到最后输出的图片中狗狗的姿势就全部都是趴着的了,这样的过拟合是我们不希望出现的。
7.实验
核心讲完了,接下来就来看看DreamBooth的实验吧!
(1).数据集
Google Research在DreamBooth仓库中一行代码都没有留下,不过把论文中训练模型用的数据集都放在里面了:https://github.com/google/dreambooth
(2).训练消耗
- 训练参数:训练集采用3~5张图片,在Imagen下设置学习率为1e-5,Stable Diffusion下设置学习率为5e-6,均训练1000轮
- 结果:Imagen上采取单张TPUv4训练,耗时5分钟;Stable Diffusion上采取单张A100训练,耗时5分钟
我靠,真要有这效果,那DreamBooth真是神了,时间短而且效果好,真的这么理想吗?我们等会儿再说。
(3).消融实验
有的时候我们不能直接通过模型变化理解模型变好的原因,所以可以通过消融实验,去除这些变化因素,再与包含相应变量的模型的效果进行对比,从而验证变量在模型当中的效果。
DreamBooth尝试了两个方向的消融实验:Prior Preservation Loss和Class-Prior,第一个就是验证PPL的影响,第二个则是验证类表名称正确与否对于生成图片的影响。
(4).Prior Preservation Loss Ablation
#1.先看看图
- 实验中输入的图片特地采用了同一姿势的狗狗的照片进行训练。
- 结果比较明显,在没有PPL的情况下,生成的几张图片虽然保真度较好,但是狗狗的姿势都是趴着的。
- 而有PPL的组中就能在保真的情况下生成不同的姿势。
而在没有PPL的情况下的DreamBooth微调模型,就出现了我们前面说的,不希望模型出现的过拟合现象——毕竟我们拍照片也不可能总是只有一个姿势对吧?
#2.再看看数据
很明显,有PPL的DreamBooth微调后的Imagen模型在PRES,DIV和CLIP-T上表现都要比没有PPL的模型更好,不过这个DINO和CLIP-I是怎么回事?怎么还没有另一组好呢?
事实上。DINO和CLIP-I指标都是用于衡量生成图像与真实图像之间平均相似性的指标。它反映了生成图像在视觉特征上与真实图像的接近程度,但不一定代表了生成图像的质量或多样性,因此具备PPL的DreamBooth在这些指标上并不占优势,也好理解,毕竟在这里我们希望具备一定的多样性,而DINO和CLIP-I更希望保证它的不变性。
(5).Class-Prior Ablation
Class-Prior在这里更多是针对三种情况进行评估:正确使用类别词汇,不适用类别词汇以及错误使用类别词汇,我们直接看数据:
我们仍然用DINO和CLIP-I进行评估,结果很直接:正确的类别标识福的效果要远远好于使用错误的或不适用对应的类别标识词,这也说明:选择正确的类别标识词对于微调是相当重要的。
(6).效果
接下来从数据上看看DreamBooth和Textual Inversion的对比:
无论是在Imagen还是在Stable Diffusion上的微调效果都要远远好于Textual Inversion(遥遥领先),然后看看物体保真度和文本提示保真度:
DreamBooth在物体保真度和文本提示保真度也是远远领先于Textual Inversion以及没有采取任何措施的模型。
文本提示保真度指的就是生成图片和输入的提示词之间的一致性。
8.自行训练
在这里我们就采取Diffusers库提供的一系列DreamBooth训练脚本对Stable Diffusion模型进行微调训练。
(1).利用Diffusers库进行DreamBooth训练
- 组合:Stable Diffusion XL + LoRA + DreamBooth
- 训练参数:学习率1e-5,最大训练次数5000
- 资源消耗:我也用一张A100,大约占用25GB显存,并且需要3个小时才能完成训练
这个资源消耗,看起来完全达不到论文当中的要求啊!可能Google Research他们团队的论文确实做到了吧,可惜他们没有公布出来。
这里我们用了Stable Diffusion XL + LoRA来优化微调的结果,在采取Stable Diffusion v1.4的原始模型直接微调的情况下,1000步的微调需要接近8个小时才能完成训练。
这里就不放图了,在Stable Diffusion XL + LoRA的组合下使用DreamBooth微调的结果相对还比较好,但如果换成原始模型不加其他优化方法微调,DreamBooth微调出来的模型绘制对应物体的能力是非常差的,这也不太能达到DreamBooth论文中提到的效果。
总结
Google Research提出的DreamBooth成功将我们对于扩散模型的微调带到了一个新的高度,它的效果的确是不错的,虽然我们可能在自己使用的时候和论文提供的效果有所差距,但在经过一些调整之后,效果还是能基本达到预期的。