让扩散Transformer训练更容易!谢赛宁等人提出REPA:表征对齐技术

点击下方卡片,关注“CVer”公众号

AI/CV重磅干货,第一时间送达

点击进入—>【Mamba/多模态/扩散】交流群

添加微信号:CVer111,小助手会拉你进群!

扫描下方二维码,加入CVer学术星球!可以获得最新顶会/顶刊上的论文idea和CV从入门到精通资料,及最前沿应用!发论文/搞科研/涨薪,强烈推荐!

bfa89e63f0466c3415a7a4d297c1ec93.png

转载自:机器之心 |编辑:Panda、小舟

Representation matters. Representation matters. Representation matters.

是什么让纽约大学著名研究者谢赛宁三连呼喊「Representation matters」?他表示:「我们可能一直都在用错误的方法训练扩散模型。」即使对生成模型而言,表征也依然有用。基于此,他们提出了 REPA,即表征对齐技术,其能让「训练扩散 Transformer 变得比你想象的更简单。」

830eb26cd488aa2b5f6a726f161a76d5.png

Yann LeCun 也对他们的研究表示了认可:「我们知道,当使用自监督学习训练视觉编码器时,使用具有重构损失的解码器的效果远不如使用具有特征预测损失和崩溃预防机制的联合嵌入架构。这篇来自纽约大学 @sainingxie 的论文表明,即使你只对生成像素感兴趣(例如使用扩散 Transformer 生成漂亮图片),也应该包含特征预测损失,以便解码器的内部表征可以根据预训练的视觉编码器(例如 DINOv2)预测特征。」

0239dc3f73aff5d1a570309e289a886c.png

我们知道,在生成高维视觉数据方面,基于去噪的生成模型(如扩展模型和基于流的模型)的表现非常好,已经得到了广泛应用。近段时间,也有研究开始探索将扩展模型用作表征学习器,因为这些模型的隐藏状态可以捕获有意义的判别式特征。

而谢赛宁指导的这个团队发现(另一位指导者是 KAIST 的 Jinwoo Shin),训练扩散模型的主要挑战源于需要学习高质量的内部表征。他们的研究表明:「当生成式扩散模型得到来自另一个模型(例如自监督视觉编码器)的外部高质量表征的支持时,其性能可以得到大幅提升。」

REPresentation Alignment(REPA),即表征对齐技术,便基于此而诞生了。这是一个基于近期的扩散 Transformer(DiT)架构的简单正则化技术。

8e1363b948a2e55b7b6fb77f36baf577.png

  • 论文标题:Representation Alignment for Generation: Training Diffusion Transformers Is Easier Than You Think

  • 论文地址:https://arxiv.org/pdf/2410.06940

  • 项目地址:https://sihyun.me/REPA/

  • 代码地址:https://github.com/sihyun-yu/REPA

本质上讲,REPA 就是将一张清晰图像的预训练自监督视觉表征蒸馏成一个有噪声输入的扩展 Transformer 表征。这种正则化可以更好地将扩展模型表征与目标自监督表征对齐。

方法看起来很简单,但 REPA 的效果却很好!据介绍,REPA 能大幅提升模型训练的效率和效果。相比于原生模型,REPA 能将收敛速度提升 17.5 倍以上。在生成质量方面,在使用带引导间隔(guidance interval)的无分类器引导时,新方法取得了 FID=1.42 的当前最佳结果。

REPA:用于表征对齐的正则化

REPresentation Alignment(REPA)是一种简单的正则化方法,其使用了近期的扩展 Transformer 架构。简单来说,该技术就是一种将预训练的自监督视觉表征蒸馏到扩展 Transformer 的简单又有效的方法。这让扩散模型可以利用这些语义丰富的外部表征进行生成,从而大幅提高性能。

4078647dde858ce0b00c998691a00626.png

观察

REPA 的诞生基于该团队得到的几项重要观察。

他们研究了在 ImageNet 上预训练得到的 SiT(可扩展插值 Transformer)模型的逐层行为,该模型使用了线性插值和速度预测(velocity prediction)进行训练。他们研究的重点是扩散 Transformer 和当前领先的监督式 DINOv2 模型之间的表征差距。他们从三个角度进行了研究:语义差距、特征对齐进展以及最终的特征对齐。

对于语义差距,他们比较了使用 DINOv2 特征的线性探测结果与来自 SiT 模型(训练了 700 万次迭代)的线性探测结果,采用的协议涉及到对扩散 Transformer 的全局池化的隐藏状态进行线性探测。

接下来,为了测量特征对齐,他们使用了 CKNNA;这是一种与 CKA 相关的核对齐(kernel alignment)指标,但却是基于相互最近邻。这样一来,便能以量化方式评估对齐效果了。图 2 总结了其结果。

522eab90e714a0caf6cfc2f2f20f4343.png

扩散 Transformer 与先进视觉编码器之间的语义差距明显。如图 2a 所示,可以观察到,预训练扩散 Transformer 的隐藏状态表征在第 20 层能得到相当高的线性探测峰值。但是,其性能仍远低于 DINOv2,表明这两种表征之间存在相当大的语义差距。此外,他们还发现,在此峰值之后,线性探测性能会迅速下降,这表明扩散 Transformer 必定从重点学习语义丰富的表征转向了生成具有高频细节的图像。

扩散表征已经与其它视觉表征(细微地)对齐了。图 2b 使用 CKNNA 展示了 SiT 与 DINOv2 之间的表征对齐情况。可以看到,SiT 模型表征的对齐已经优于 MAE,而后者也是一种基于掩码图块重建的自监督学习方法。但是,相比于其它自监督学习方法之间的对齐分数,其绝对对齐分数依然较低。这些结果表明,尽管扩散 Transformer 表征与自监督视觉表征存在一定的对齐,但对齐程度不高。

当模型增大、训练变多时,对齐效果会更好。该团队还测量了不同模型大小和训练迭代次数的 CKNNA 值。图 2c 表明更大模型和更多训练有助于对齐。同样地,相比于其它自监督视觉编码器之间的对齐,扩散表征的绝对对齐分数依然较低。

这些发现并非 SiT 模型所独有,其它基于去噪的生成式 Transformer 也能观察到。该团队也在 DiT 模型上观察到了类似的结果 —— 其使用 DDPM 目标在 ImageNet 上完成了预训练。

与自监督表征的表征对齐

REPA 将模型隐藏状态的 patch-wise 投影与预训练自监督视觉表征对齐。具体来说,该研究使用干净的(clean)图像表征作为目标并探讨其影响。这种正则化的目的是让扩散 transformer 的隐藏状态从包含有用语义信息的噪声输入中预测噪声不变、干净的视觉表征。这能为后续层重建目标提供有意义的引导。

形式上,令 𝑓 为预训练编码器,x* 为干净图像。令 y*=𝑓(x*) ∈ ℝ^{N×D} 为编码器输出,其中 N、D > 0 分别是 patch 的数量和 𝑓 的嵌入维度。

REPA 是将9adc22b901165b2fa5c8af3c2c6395e6.png与 y* 对齐,其中a774f9142701f78f56c9eefc9494e053.png是扩散 transformer 编码器输出c1bd8eba3d20f85e84668df0a314b5bd.png通过可训练投影头 h_ϕ 得到的投影。实践中 h_ϕ 的参数化是简单地使用多层感知器(MLP)完成的。

特别地,REPA 通过最大化预训练表征 y* 和隐藏状态 h_t 之间的 patch-wise 相似性来实现对齐,其中 n 是 patch 索引,sim (・,・) 是预定义的相似度函数。

6dbcacb306b8060f2417d0b348639dbe.png

在实践中,是基于一个系数 λ 将该项添加到基于扩散的原始目标中。例如,对于速度模型的训练,其目标变为:

9615cf5b8b8e9ee25c1a8989014f1e1c.png

其中 λ > 0 是一个超参数,用于控制去噪和表示对齐之间的权衡。该团队主要研究这种正则化对两个常用目标的影响:DiT 中使用的改进版 DDPM 和 SiT 中使用的线性随机插值,尽管也可以考虑其他目标。

结果

REPA 改善视觉扩展

该研究首先比较两个 SiT-XL/2 模型在前 400K 次迭代期间生成的图像,其中一个模型应用 REPA。两种模型共享相同的噪声、采样器和采样步骤数,并且都不使用无分类器引导。使用 REPA 训练的模型表现更好。

7fab892f69f623ce125b2de0dd882201.png

REPA 在各个方面都展现出强大的可扩展性

该研究通过改变预训练编码器和扩散 transformer 模型大小来检查 REPA 的可扩展性,结果表明:与更好的视觉表征相结合可以改善生成和线性探测结果。 

06e3e0c2c05d48170126c8b7b469bd40.png

REPA 还在大型模型中提供了更显著的加速,与普通模型相比,实现了更快的 FID-50K 改进。此外,增加模型大小可以在生成和线性评估方面带来更快的增益。

REPA 显著提高训练效率和生成质量

最后,该研究比较了普通 DiT 或 SiT 模型与使用 REPA 训练的模型的 FID 值。

75ce7c9d4436290d910f70d5eed8a71a.png

在没有无分类器引导的情况下,REPA 在 400K 次迭代时实现了 FID=7.9,优于普通模型在 700 万次迭代时的性能。

使用无分类器引导,带有 REPA 的 SiT-XL/2 的性能优于最新的扩散模型,迭代次数减少为 1/7,并通过额外的引导调度实现了 SOTA FID=1.42。

该团队也执行了消融研究,探索了不同时间步数、不同视觉编码器和不同 λ 值(正则化系数)的影响。详见原论文。

 
 

何恺明在MIT授课的课件PPT下载

 
 

在CVer公众号后台回复:何恺明,即可下载本课程的所有566页课件PPT!赶紧学起来!

ECCV 2024 论文和代码下载

在CVer公众号后台回复:ECCV2024,即可下载ECCV 2024论文和代码开源的论文合集

CVPR 2024 论文和代码下载

在CVer公众号后台回复:CVPR2024,即可下载CVPR 2024论文和代码开源的论文合集

Mamba、多模态和扩散模型交流群成立

 
 
扫描下方二维码,或者添加微信号:CVer111,即可添加CVer小助手微信,便可申请加入CVer-Mamba、多模态学习或者扩散模型微信交流群。另外其他垂直方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch、TensorFlow和Transformer、NeRF、3DGS、Mamba等。
一定要备注:研究方向+地点+学校/公司+昵称(如Mamba、多模态学习或者扩散模型+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群

 
 
▲扫码或加微信号: CVer111,进交流群
CVer计算机视觉(知识星球)来了!想要了解最新最快最好的CV/DL/AI论文速递、优质实战项目、AI行业前沿、从入门到精通学习教程等资料,欢迎扫描下方二维码,加入CVer计算机视觉(知识星球),已汇集上万人!

▲扫码加入星球学习
 
 
▲点击上方卡片,关注CVer公众号
整理不易,请赞和在看
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值