极市平台 | Mamba联名Diffusion?DiM:无需微调,高分辨图像生成更高效!

本文来源公众号“极市平台”,仅用于学术分享,侵权删,干货满满。

原文链接:Mamba联名Diffusion?DiM:无需微调,高分辨图像生成更高效!

极市导读

本文提出了一种新的基于 Mamba 的扩散模型 DiM,用于高效的高分辨率图像生成。Mamba 本是用于处理一维信号的模型,作者提出了几种有效的设计来使其能够对二维图像进行建模。

0 本文目录

1 DiM:高效高分辨率图像生成的 Diffusion Mamba
(来自香港大学,华为诺亚方舟实验室)
1 DiM 论文解读
1.1 用 Mamba 架构进行高分辨率图像生成
1.2 状态空间模型
1.3 Diffusion Mamba 架构
1.4 训练和推理策略
1.5 实验设置
1.6 效率分析
1.7 实验结果

太长不看版

扩散模型在图像生成方面取得了巨大成功,Backbone 从 U-Net 演变到 Vision Transformer。然而, Transformer 的计算成本与 token 的数量成二次方,在处理高分辨率图像时面临重大挑战。本文提出 Diffusion Mamba (DiM),它结合了 Mamba 的效率,且具有扩散模型的表达能力,以实现高效的高分辨率图像合成。Mamba 是一种基于状态空间模型 (State Space Models, SSM) 的序列模型。

为了解决 Mamba 不能泛化到 2D 信号的挑战,作者提出了几种架构设计,包括多方向扫描、每行和列末尾的 learnable padding tokens 以及轻量级局部特征增强。DiM 架构可以高效地生成高分辨率图像。此外,为了进一步提高 DiM 高分辨率图像生成的训练效率,作者研究了在低分辨率图像 (256×256) 上预训练 DiM 的 "weak-to-strong" 训练策略,然后在高分辨率图像上微调它 (512×512)。作者进一步探索了 training-free 的上采样策略,使模型能够生成更高分辨率的图像 (例如 1024×1024 和 1536×1536),而无需进一步微调。实验证明了 DiM 的有效性和效率。

图1:本文模型在 ImageNet 上训练的图像生成的图像。分辨率分别为 1024×1024,512×512 和 256×256,classifier-free guidance 权重为 4.0

本文做了哪些具体的工作

  1. 提出了一种新的基于 Mamba 的扩散模型 DiM,用于高效的高分辨率图像生成。Mamba 本是用于处理一维信号的模型,作者提出了几种有效的设计来使其能够对二维图像进行建模。

  2. 为了解决高分辨率图像训练的高成本,作者研究了微调在低分辨率图像上预训练的 DiM 以进行高分辨率图像生成的策略。此外还探索了 training-free 的上采样方案,使模型在无需进一步的微调的情况下生成更高分辨率的图像。

  3. 在 ImageNet 和 CIFAR 上的实验证明了 DiM 在高分辨率图像生成中的训练效率、推理效率和有效性。

1 DiM:高效高分辨率图像生成的 Diffusion Mamba

论文名称:DiM: Diffusion Mamba for Efficient High-Resolution Image Synthesis (Arxiv 2024.05)

论文地址:

https://arxiv.org/pdf/2405.14224

代码链接:

http://github.com/tyshiwo1/DiM-DiffusionMamba/

1.1 用 Mamba 架构进行高分辨率图像生成

扩散模型在图像生成方面取得了巨大的成功。由于 Transformer 架构的有效性和可扩展性,扩散模型的 Backbone 已经从以 U-Net[1]为代表的卷积神经网络发展到 Vision Transformer[2][3][4][5]。基于 Transformer 的扩散模型将图像编码为 latent 特征图,再把 latent 特征图分成 Patches,再把这些 Patches 投影为 tokens。然后,应用 Transformer 对图像 tokens 进行去噪。但是,Transformer 中的 Self-attention 层的复杂度与 tokens 的数量呈二次方关系,使得高分辨率图像生成的计算成本面临重大挑战。

Mamba[6]是一种基于状态空间模型 (State Space Models, SSM) 的序列模型 Backbone,在语言、音频和基因组学等几种模式中显示出显著的有效性和效率。Mamba 实现了与 Transformer 相当的性能,且具有更好的推理时间和效率。与 Transformer 的二次计算复杂度相比,Mamba 在长序列建模中显示出巨大的前景,因为 Mamba 的计算复杂度与 token 的数量成线性关系。Mamba 的这些特性促使本文作者将 Mamba 作为扩散模型的新的 Backbone 引入,尤其是对于高效的高分辨率图像生成。

然而,当将 Mamba 与扩散模型相结合以进行高分辨率图像生成时,会出现一些挑战。主要挑战是 Mamba 的因果序列建模与图像的二维 (2D) 数据结构之间的不匹配。Mamba 架构是为序列信号的一维 (1D) 因果建模而设计的,不能直接用于建模二维图像 tokens。一个简单的解决方案是使用光栅扫描顺序将 2D 数据转换为 1D 的序列。但是,它将每个位置的感受野限制为只有光栅扫描顺序中的先前位置。此外,在光栅扫描顺序中,当前行的结尾后面是下一行的开始,它们之间不共享空间连续性。第二个挑战是,尽管 Mamba 具有高效推理的优势,但在高分辨率图像上训练基于 Mamba 的扩散模型的训练代价依然昂贵。

为了缓解第一个挑战,本文作者提出了 Diffusion Mamba (DiM),如下图 2 所示,这是一种基于 Mamba 的扩散模型 Backbone,用于高效的高分辨率图像生成。在 DiM 中,作者遵循 Diffusion Transformer 的做法将图像编码为 Patch 特征。然后,作者使用 Mamba 架构作为 Backbone 来建模特征。为了避免补丁之间的单向因果关系并赋予每个 token 全局感受野,作者设计了 Mamba Block 交替执行四个扫描方向。此外,作者在扫描顺序相邻的两个标记之间插入 learnable padding tokens,但在空间域中不相邻,从而允许 Mamba 块识别图像边界并避免误导序列模型。作者还将 3×3 Depth-Wise Convolution 添加到网络的输入层和输出层,以增强生成图像的局部相干性。此外,作者在浅层和深层之间添加了长跳跃连接[1][3]。基于 Transfo,以将低级信息传播到高级特征,这一点已经被证明有利于扩散模型中的像素级预测目标。

图2:DiM 框架:框架的输入是 noisy image 或者 latent,具有 time step 和 class condition

为了解决 DiM 高分辨率图像生成训练效率的挑战,作者探索了资源高效的方法,使得在低分辨率图像上预训练的 DiM 模型可以完成高分辨率的图像生成。作者首先观察到在低分辨率图像上预训练的 DiM 可以为高分辨率图像生成提供合理的先验。因此,作者探索了 "weak-to-strong" 训练策略[7]。作者首先在低分辨率图像上训练 DiM,然后使用预训练模型作为初始化来有效地微调高分辨率图像。该策略大大降低了高分辨率图像生成的训练时间成本。作者还探索了 training-free 的上采样方法,使 DiM 进一步适应更高分辨率的图像,而无需进一步微调。

1.2 状态空间模型

1.3 Diffusion Mamba 架构

Mamba 主要用于处理一维输入,因此 Mamba 很难在没有任何修改的情况下学习图像的二维数据结构。因此,作者提出了一些新的架构设计,使 DiM 能够处理空间结构。

总体架构

如图2所示,DiM 框架可以处理有噪声的二维 (2D) 输入,比如图像或者 latent 的特征,同时需要输入 time step 和 class condition。这种噪声输入可以被视为由对应于输入时间步长的特定高斯噪声级别扰动的干净信号。噪声输入首先被分成 2D Patches,每个 Patch 可以通过全连接层转换为高维特征向量。接下来,这些 Patches 被送入 3×3 Depth-Wise Convolution 层,其中局部信息被注入到 Patches 中。Patches 也在行和列的末尾用可学习的 tokens 填充,允许模型在一维顺序扫描期间感知二维空间结构。然后,使用图2所示的四个扫描模式之一,将 Patch tokens 展平为 Patch 序列。time step 和 class condition 也通过全连接层转换为 tokens,然后附加到序列中。随后,序列被送入 Mamba Blocks 进行扫描。此外,作者还在浅层和深层之间添加了长跳跃连接,以将低级信息传播到高级特征,这也被证明有利于扩散模型中的像素级预测目标。

扫描模式

全局感受野对于本文模型有效地捕获图像中的空间结构至关重要。在单个光栅扫描方向上扫描图像 Patches 会导致单向有限的感受野。例如,左上角的第一个扫描 Patch 永远不会聚合来自其他 Patch 的信息。为了使每个 Patch 具有全局感受野,作者在不同的模型块中采用了不同的扫描模式。具体来说,如图2所示,在第1个 Block 中,采用行主扫描,即逐行扫描图像补丁序列,每一行从左到右水平扫描,然后移动到下一行。在第2个 Block 中,反转序列顺序并以相同的方式扫描序列。在随后的 Block 中,作者以正向和反向顺序执行列主扫描。在遍历所有扫描模式后,在下一个模型块中再次循环它们。

可学习的 padding token

图像的空间结构的学习可能会被光栅扫描破坏。具体来说,当将图像展平为 Patch 序列时,图像一行中最右边的补丁变得与第二行最左边的补丁相邻。然而,这两个特征向量所代表的内容可能存在很大差异。因为图像有固有的连续性和空间结构,但是扫描的方式与这种结构相矛盾,从而阻碍了学习的过程。为了缓解这个问题,作者在每一行或者每一列的末尾增加可学习的 padding token,使得模型意识到 End-Of-Line (EOL)。

轻量级局部特征增强

图像的局部结构会会被扫描的序列化过程所破坏。比如在行主扫描中, 行 i 和列 j 处的 Patch 不再与行 i+1 和列 j 处的 Patch 相邻。此外, 由于 Mamba 专为极端的效率优化而设计, 因此选择通过在网络的开头和结尾添加几个轻量级模块来增强局部结构, 而不改变 Mamba Block 本身。

具体来讲,作者引入了两个 3×3 Depth-Wise Convolution 层。在将 tokens 输入给 Mamba Block 之前,在 Patchify 层之后插入一个卷积层。在 Unpatchify 和输出层之前,在所有 Mamba Block 之后插入另一个卷积层。这些轻量级的 Depth-Wise Convolution 层为 DiM 提供了对 2D 局部连续性的认识。

1.4 训练和推理策略

尽管推理效率不错,但在高分辨率图像上训练 DiM 仍需要大量的时间和计算资源。

"Weak-to-strong" 训练和微调

从头开始训练高分辨率图像的扩散模型需要大量的时间和计算资源。作者观察到,在低分辨率图像上预训练的 DiM 可以为高分辨率训练提供粗略的初始化,如下图3所示。因此,作者考虑了一种 "Weak-to-strong" 的训练策略,在低分辨率图像上从头开始训练 DiM 模型,然后对更高的分辨率进行微调。在微调期间,我们将图像的长度和宽度提高2倍。该策略大大降低了使用 DiM 训练高分辨率图像生成器的计算成本。

图3:在 256×256 图像上训练的模型生成的初始 512×512 图像,在微调之前没有任何额外的技术

Training-free 上采样

1.5 实验设置

如下图4所示为 DiM 模型的3个版本,模型大小不同。

图4:DiM 模型的3个版本

输入大小设置为 32×32,没有 image auto-encoder[10]。Mamba Block 的超参数作者遵循标准设置[6]。将 ImageNet 和 CIFAR 上训练的 DiM 的 Patch Size 设置为 2×2。

所有的训练实验都在 8 A100-80G 上执行。继之前的工作 U-ViT 之后,作者使用相同的 DDPM scheduler、预训练的图像自动编码器和 DPM-Solver。使用随机翻转作为数据增强。学习率设置为 2×10−4。还使用速率为0.9999 的 EMA。

评价指标为 FID-50K[11],数据集为 CIFAR 和 ImageNet。作者还使用 classifier-free guidance 进行评估,计算 FID 的指导权重与 U-ViT 中的指导权重相同。考虑到有限的 GPU 资源,在 ImageNet 256×256 上进行预训练时,将 DiM-Large 和 DiM-Huge 的 Batch Size 设置为 1024 和 768。在 ImageNet 512 × 512 上微调 DiM-Huge 时,将 Batch Size 设置为 240,梯度累积。

在 ImageNet 上,作者首先在 256×256 分辨率下预训练超过 300K iterations 的 DiM 模型。然后,以 512×512 的分辨率微调预训练模型。为了在不增加训练成本的情况下实现更高的分辨率,作者进一步使用 training-free 的上采样技术生成 1024×1024 和 1536×1536 图像,DiM-Huge 在 512×512 分辨率下训练。

1.6 效率分析

作者检查 DiM 的效率,并将其与 Transformer 主干进行比较。单个选择性扫描比 FlashAttention V2[12]更高效。然而,为了保持相似数量的参数,标准 Mamba 的 Block 数是 Transformer 的两倍。扫描的加倍增加了计算复杂度。此外,作者提出的包括扫描模式的切换在内的模块也会造成轻微的延迟。为了比较 Transformer 和 Mamba 在图像生成方面的实际效率,作者在单个 H800 GPU 上进行了实验。

作者在图5中展示了本文模型,U-ViT 和 Mamba Baseline 的推理速度。这些模型具有相似的参数量 (0.9B) 和相同的 2×2 Patch Size。可以看到,原始的 Mamba 基线和 DiM 在分辨率低于 1024×1024 的情况下比优化良好的基于 Transformer 的模型慢。然而,在分辨率高于 1280×1280 情况下,DiM 比 Transformer 更快,这要归功于其线性复杂度。

图5:不同分辨率的模型的推理速度。每个模型有 0.9 亿个参数。将速度表示为 iterations per millisecond 的对数

而且,DiM 的效率仅略低于 Mamba Baseline,这表明作者添加到原始 Mamba 的设计使 Mamba 适应 2D 图像,但不会造成较大的额外计算成本。

1.7 实验结果

ImageNet 数据集生成质量

作者选择一组生成的图像进行可视化。结果表明,在 ImageNet 上预训练的 DiM-Huge 可以生成高质量的 256×256 图像,如图 6(b) 所示。本文模型在分辨率为 512×512 的 ImageNet 上微调的模型也显示出出色的性能,如图 6(a) 所示。

图6:DiM-Huge 在 cfg=4.0 时生成的结果

可以使用在 512×512 ImageNet 上训练的模型直接生成 1024×1024 和 1536×1536 图像。如图7所示,即使分辨率增加到训练的3倍,本文模型仍然能够生成具有上采样引导的视觉上吸引人的图像。

图7:在 512×512 图像上训练的 DiM-Huge 生成的高分辨率图像

ImageNet 256×256 预训练数值结果

考虑到有限的计算资源和时间,作者只能在最多 319 million 图像样本上训练模型。作者将 DiM 与其他基于 Transformer 和 SSM 的扩散模型进行了比较,如下图8所示。值得注意的是,在对 319 million 个图像样本进行训练后,DiM-Huge 在 FID-50K 上可以达到 2.40 的分数。在使用 U-ViT (319M vs 500M) 63% 的训练数据的情况下,本文模型的性能与其他基于 Transformer 的扩散模型相当,即在 FID-50K 上仅差约 0.1。此外,与 DiffuSSM-XL 相比,基于 Mamba 的扩散模型的 GFLOPs 要小得多,即 DiM 需要更少的资源进行推理。

图8:在分辨率为 256×256 的 ImageNet 上进行预训练结果

ImageNet 512×512 微调数值结果

在 512×512 图像样本进行训练需要大量的计算资源。此外,这种较大的分辨率在训练和推理过程中造成了不可忽略的延迟,如图9所示。因此,作者没有从头开始训练,而是从在 ImageNet 256×256 上进行预训练之后的模型为初始化,来微调本文的 DiM-Huge。作者只使用了 U-ViT 的 512×512 训练数据的 3% (15M vs 500M),DiM-Huge 就达到了 3.94 FID50K。尽管仍然远未达到最佳性能,但 DiM-Huge 能够产生视觉上吸引人的 512×512 图像,如图7(a)所示。

图9:ImageNet 512×512 上各种模型的结果。表中的预训练表示以 256×256 分辨率训练的 DiM

CIFAR-10 的实验结果如下图10所示。本文方法可以与其他具有相似参数量的方法相比实现相当的性能。

图10:CIFAR-10 实验结果

消融实验结果

作者在 CIFAR-10 数据集上进行了消融实验,FIDs 结果如图11所示。其中第一行包含最佳性能模型的结果,其他行的性能对应于没有某些组件的模型。

图11:CIFAR-10 消融实验结果。MS:多次扫描模式。LSC:长 skip-connection。PT:padding token。LFE:局部特征增强

根据第一行和最后一行,多次扫描模式会对结果有帮助,说明全局感受野的重要性。

作者还发现长距离 skip-connection 有助于收敛。

此外,卷积和 padding token 也有助于提升性能。

参考

  1. ^abU-Net: Convolutional Networks for Biomedical Image Segmentation

  2. ^Scalable Diffusion Models with Transformers

  3. ^abAll are worth words: A vit backbone for diffusion models

  4. ^Pixart-alpha: Fast training of diffusion transformer for photorealistic text-to-image synthesis

  5. ^Scaling rectified flow transformers for high-resolution image synthesis

  6. ^abMamba: Linear-Time Sequence Modeling with Selective State Spaces

  7. ^PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation

  8. ^ImageNet: A large-scale hierarchical image database

  9. ^Upsample guidance: Scale up diffusion models without training

  10. ^Auto-Encoding Variational Bayes

  11. ^GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium

  12. ^FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

  • 39
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值