颜水成/程明明新作!Sora核心组件DiT训练提速10倍!掩码扩散Transformer V2开源!...

点击下方卡片,关注“CVer”公众号

AI/CV重磅干货,第一时间送达

点击进入—>【扩散模型和多模态】交流群

添加微信:CVer444,小助手会拉你进群!

扫描下方二维码,加入CVer学术星球可以获得最新顶会/顶刊上的论文idea和CV从入门到精通资料,及最前沿应用!发论文搞科研,强烈推荐!

4242cc54c9d5d4ee48ab004e4ee15fe1.jpeg

转载自:新智元 | 编辑:LRS 好困

【导读】Masked Diffusion Transformer V2在ImageNet benchmark 上实现了1.58的FID score的新SoTA,并通过mask modeling表征学习策略大幅提升了DiT的训练速度。

DiT作为效果惊艳的Sora的核心技术之一,利用Difffusion Transfomer 将生成模型扩展到更大的模型规模,从而实现高质量的图像生成。

然而,更大的模型规模导致训练成本飙升。

为此,来自Sea AI Lab、南开大学、昆仑万维2050研究院的颜水成和程明明研究团队在ICCV 2023提出的Masked Diffusion Transformer利用mask modeling表征学习策略通过学习语义表征信息来大幅加速Diffusion Transfomer的训练速度,并实现SoTA的图像生成效果。

6d3f1b07f3cc98540a2487d83e18a8f9.png

论文地址:https://arxiv.org/abs/2303.14389

GitHub地址:https://github.com/sail-sg/MDT

近日,Masked Diffusion Transformer V2再次刷新SoTA, 相比DiT的训练速度提升10倍以上,并实现了ImageNet benchmark 上 1.58的FID score。

最新版本的论文和代码均已开源。

背景

尽管以DiT 为代表的扩散模型在图像生成领域取得了显著的成功,但研究者发现扩散模型往往难以高效地学习图像中物体各部分之间的语义关系,这一局限性导致了训练过程的低收敛效率。

5f9b26e62d24873084affd5c79be895d.png

例如上图所示,DiT在第50k次训练步骤时已经学会生成狗的毛发纹理,然后在第200k次训练步骤时才学会生成狗的一只眼睛和嘴巴,但是却漏生成了另一只眼睛。

即使在第300k次训练步骤时,DiT生成的狗的两只耳朵的相对位置也不是非常准确。

这一训练学习过程揭示了扩散模型未能高效地学习到图像中物体各部分之间的语义关系,而只是独立地学习每个物体的语义信息。

研究者推测这一现象的原因是扩散模型通过最小化每个像素的预测损失来学习真实图像数据的分布,这个过程忽略了图像中物体各部分之间的语义相对关系,因此导致模型的收敛速度缓慢。

方法:Masked Diffusion Transformer

受到上述观察的启发,研究者提出了Masked Diffusion Transformer (MDT) 提高扩散模型的训练效率和生成质量。

MDT提出了一种针对Diffusion Transformer 设计的mask modeling表征学习策略,以显式地增强Diffusion Transformer对上下文语义信息的学习能力,并增强图像中物体之间语义信息的关联学习。

8c00dbe2550126a9986653337b4e47fa.png

如上图所示,MDT在保持扩散训练过程的同时引入mask modeling学习策略。通过mask部分加噪声的图像token,MDT利用一个非对称Diffusion Transformer (Asymmetric Diffusion Transformer) 架构从未被mask的加噪声的图像token预测被mask部分的图像token,从而同时实现mask modeling 和扩散训练过程。

在推理过程中,MDT仍保持标准的扩散生成过程。MDT的设计有助于Diffusion Transformer同时具有mask modeling表征学习带来的语义信息表达能力和扩散模型对图像细节的生成能力。

具体而言,MDT通过VAE encoder将图片映射到latent空间,并在latent空间中进行处理以节省计算成本。

在训练过程中,MDT首先mask掉部分加噪声后的图像token,并将剩余的token送入Asymmetric Diffusion Transformer来预测去噪声后的全部图像token。 

Asymmetric Diffusion Transformer架构

f28c3f29f585d17a1de5139ae9a6958b.png

如上图所示,Asymmetric Diffusion Transformer架构包含encoder、side-interpolater(辅助插值器)和decoder。

102c57821ee3dd017b6fb4b3fae7eeec.png

在训练过程中,Encoder只处理未被mask的token;而在推理过程中,由于没有mask步骤,它会处理所有token。

因此,为了保证在训练或推理阶段,decoder始终能处理所有的token,研究者们提出了一个方案:在训练过程中,通过一个由DiT block组成的辅助插值器(如上图所示),从encoder的输出中插值预测出被mask的token,并在推理阶段将其移除因而不增加任何推理开销。

MDT的encoder和decoder在标准的DiT block中插入全局和局部位置编码信息以帮助预测mask部分的token。

Asymmetric Diffusion Transformer V2

b90831bda2a602afcfba2b9476b3f8ce.png

如上图所示,MDTv2通过引入了一个针对Masked Diffusion过程设计的更为高效的宏观网络结构,进一步优化了diffusion和mask modeling的学习过程。

这包括在encoder中融合了U-Net式的long-shortcut,在decoder中集成了dense input-shortcut。

其中,dense input-shortcut将添加噪后的被mask的token送入decoder,保留了被mask的token对应的噪声信息,从而有助于diffusion过程的训练。

此外,MDT还引入了包括采用更快的Adan优化器、time-step相关的损失权重,以及扩大掩码比率等更优的训练策略来进一步加速Masked Diffusion模型的训练过程。

实验结果

ImageNet 256基准生成质量比较

24c1d36e1ca0af20dec070fbd28d099e.png

上表比较了不同模型尺寸下MDT与DiT在ImageNet 256基准下的性能对比。

显而易见,MDT在所有模型规模上都以较少的训练成本实现了更高的FID分数。

MDT的参数和推理成本与DiT基本一致,因为正如前文所介绍的,MDT推理过程中仍保持与DiT一致的标准的diffusion过程。

对于最大的XL模型,经过400k步骤训练的MDTv2-XL/2,显著超过了经过7000k步骤训练的DiT-XL/2,FID分数提高了1.92。在这一setting下,结果表明了MDT相对DiT有约18倍的训练加速。

对于小型模型,MDTv2-S/2 仍然以显著更少的训练步骤实现了相比DiT-S/2显著更好的性能。例如同样训练400k步骤,MDTv2以39.50的FID指标大幅领先DiT 68.40的FID指标。

更重要的是,这一结果也超过更大模型DiT-B/2在400k训练步骤下的性能(39.50 vs 43.47)。

ImageNet 256基准CFG生成质量比较

341188109adfd0d4b6a777bfeec81e53.png

我们还在上表中比较了MDT与现有方法在classifier-free guidance下的图像生成性能。

MDT以1.79的FID分数超越了以前的SOTA DiT和其他方法。MDTv2进一步提升了性能,以更少的训练步骤将图像生成的SOTA FID得分推至新低,达到1.58。

与DiT类似,我们在训练过程中没有观察到模型的FID分数在继续训练时出现饱和现象。

f136b99967d5f348ff98cf061f07eca5.png

MDT在PaperWithCode的leaderboard上刷新SoTA

收敛速度比较

944ad3b9a14e850bb2ca184d0fa2ff96.png

上图比较了ImageNet 256基准下,8×A100 GPU上DiT-S/2基线、MDT-S/2和MDTv2-S/2在不同训练步骤/训练时间下的FID性能。

得益于更优秀的上下文学习能力,MDT在性能和生成速度上均超越了DiT。MDTv2的训练收敛速度相比DiT提升10倍以上。

MDT在训练步骤和训练时间方面大相比DiT约3倍的速度提升。MDTv2进一步将训练速度相比于MDT提高了大约5倍。

例如,MDTv2-S/2仅需13小时(15k步骤)就展示出比需要大约100小时(1500k步骤)训练的DiT-S/2更好的性能,这揭示了上下文表征学习对于扩散模型更快的生成学习至关重要。

总结&讨论

MDT通过在扩散训练过程中引入类似于MAE的mask modeling表征学习方案,能够利用图像物体的上下文信息重建不完整输入图像的完整信息,从而学习图像中语义部分之间的关联关系,进而提升图像生成的质量和学习速度。

研究者认为,通过视觉表征学习增强对物理世界的语义理解,能够提升生成模型对物理世界的模拟效果。这正与Sora期待的通过生成模型构建物理世界模拟器的理念不谋而合。希望该工作能够激发更多关于统一表征学习和生成学习的工作。

参考资料:

https://arxiv.org/abs/2303.14389

何恺明在MIT授课的课件PPT下载

在CVer公众号后台回复:何恺明,即可下载本课程的152页课件PPT!赶紧学起来!

CVPR 2024 论文和代码下载

在CVer公众号后台回复:CVPR2024,即可下载CVPR 2024论文和代码开源的论文合集

多模态和扩散模型交流群成立

 
 
扫描下方二维码,或者添加微信:CVer444,即可添加CVer小助手微信,便可申请加入CVer-多模态和扩散模型微信交流群。另外其他垂直方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch、TensorFlow和Transformer、NeRF等。
一定要备注:研究方向+地点+学校/公司+昵称(如多模态或者扩散模型+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群

 
 
▲扫码或加微信号: CVer444,进交流群
CVer计算机视觉(知识星球)来了!想要了解最新最快最好的CV/DL/AI论文速递、优质实战项目、AI行业前沿、从入门到精通学习教程等资料,欢迎扫描下方二维码,加入CVer计算机视觉(知识星球),已汇集近万人!

▲扫码加入星球学习
 
 
▲点击上方卡片,关注CVer公众号
整理不易,请点赞和在看
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值