ZigMa: A DiT-style Zigzag Mamba Diffusion Model

ZigMa: DiT风格之字形Mamba扩散模型

在这里插入图片描述

论文链接:https://arxiv.org/abs/2403.13802

项目链接:https://taohu.me/zigma/

Abstract

扩散模型长期以来一直受到可扩展性和二次复杂度问题的困扰,特别是在基于Transformer的结构中。在本研究中,我们的目标是利用称为Mamba的状态空间模型的长序列建模能力来扩展其对可视化数据生成的适用性。首先,我们确定了目前大多数基于Mamba的视觉方法中的一个关键疏忽,即在Mamba扫描方案中缺乏对空间连续性的考虑。其次,在此基础上,我们介绍了Zigzag Mamba,这是一种简单、即插即用、最小参数负担、DiT风格的解决方案,与基于Transformer的基线相比,它优于基于Mamba的基线,并展示了更高的速度和内存利用率。最后,我们将Zigzag Mamba与随机插值框架相结合,研究了该模型在大分辨率视觉数据集(如FacesHQ 1024 × 1024和UCF101、MultiModal-CelebA-HQ和MS COCO 256 × 256)上的可扩展性。

1 Introduction

扩散模型在图像处理[67]、视频分析[40]、点云处理[87]和人体姿态估计[29]等各种应用中都取得了重大进展。其中许多模型建立在潜在扩散模型(Latent Diffusion models, LDM)[67]之上,后者通常基于UNet主干。然而,可扩展性仍然是LDM的一个重大挑战[41]。最近,基于Transformer的结构因其可扩展性[9,65]和在多模态训练中的有效性[10]而受到欢迎。值得注意的是,基于Transformer的结构DiT[65]甚至有助于OpenAI增强高保真视频生成模型SORA[64]。尽管通过诸如窗口[58]、滑动[12]、稀疏化[18,46]、散列[19,74]、Ring注意力[14,53]、Flash注意力[21]或它们的组合[8,97]等技术来减轻注意力机制的二次复杂度,但它仍然是扩散模型的瓶颈。

另一方面,状态空间模型[31,32,35]在长序列建模方面显示出巨大的潜力,与基于Transformer的方法竞争。几个已经提出了一些方法[27,30,32,69]来增强状态空间模型的鲁棒性[92]、可扩展性[30]和效率[32,33]。其中,一种名为Mamba的方法[30]旨在通过高效的并行扫描和其他依赖数据的创新来缓解这些问题。然而,Mamba的优势在于一维序列建模,将其扩展到二维图像是一个具有挑战性的问题。先前的研究[57,96]提出了通过计算机层次结构(如行-列-主顺序)直接平坦化二维标记,但这种方法忽略了空间连续性,如图1所示。其他工作[54,60]考虑在单个Mamba块中的各个方向,但这引入了额外的参数和GPU内存负担。在本文中,我们旨在强调Mamba空间连续性的重要性,并提出了几种直观和简单的方法,通过在图像中结合基于连续性的归纳偏置,使Mamba块应用于二维图像。通过对三维序列进行时空分解,将这些方法推广到三维。

在这里插入图片描述

最后,Stochastic Interpolant[3]提供了一个更广义的框架,可以统一各种生成模型,包括Normalizing Flow[16]、扩散模型[39,70,72]、Flow matching[4,51,56]和Schrödinger Bridge[52]。以前,一些作品[61]在相对较小的分辨率上探索随机插值,例如256×256, 512×512。在这项工作中,我们的目标是在更复杂的场景中探索它,例如1024 × 1024分辨率,甚至在视频中。

综上所述,我们的贡献如下:首先,我们确定了将Mamba块从一维序列建模推广到二维图像和三维视频建模的空间连续性的关键问题。基于这一见解,我们提出了一个简单的,即插即用,零参数的范式,命名为Zigzag Mamba(ZigMa),利用空间连续性来最大限度地结合视觉数据的归纳偏置。其次,我们将方法从二维扩展到三维,通过分解空间和时间序列来优化性能。其次,我们在扩散模型的范围内提供了围绕Mamba块的综合分析。最后,我们证明了我们设计的Zigzag Mamba优于相关的基于Mamba的基线,代表了大规模图像数据(1024×1024)和视频上随机插值的首次探索。

2 Related Works

Mamba。一些研究[82,83]已经证明状态空间模型在一定条件下具有普遍逼近能力。Mamba作为一种新型的状态空间模型,具有高效建模长序列的优越潜力,在医学成像[60,68,86,89]、图像恢复[34,95]、图形[11]、NLP word byte[80]、表格数据[2]、点云[49]、图像生成[25]等多个领域得到了探索。其中与我们关系最密切的是VisionMamba[57,96]、S4ND[63]和Mamba-ND[48]。VisionMamba[57,96]在判别性任务中使用双向SSM,这导致了很高的计算成本。我们的方法在生成模型中应用了一个简单的替代Mamba扩散。S4ND[63]在Mamba的推理过程中引入了局部卷积,超越了仅使用一维数据。Mamba-ND[48]在判别任务中考虑了多维度,在单个块内使用各种扫描。相比之下,我们的重点是在网络的每一层分布扫描复杂性,从而在零参数负担的情况下最大限度地结合视觉数据的归纳偏置。

扩散模型中的主干。扩散模型主要采用基于Unet[39,67]和基于ViT[9,65]的主干。UNet以高内存需求而闻名[67],而ViT则受益于可扩展性[17,22]和多模式学习[10]。然而,ViT的二次复杂度限制了视觉标记处理,促使人们研究如何缓解这一问题[12,21,84]。我们的工作受到Mamba[30]的启发,探索了一种基于SSM的模型作为通用扩散主干,保留了ViT的模态不可知和顺序建模优势。同时,DiffSSM[90]专注于S4模型中的无条件条件和类条件作用[32]。DIS[25]主要在相对较小的尺度上探索状态空间模型,这并不是我们的工作重点。我们的工作与他们的工作有很大的不同,因为它主要关注使用Mamba块的主干设计,并将其扩展到文本调节。此外,我们将该方法应用于更复杂的视觉数据。

扩散模型中的SDE和ODE。基于分数的生成模型领域包含了基础工作的重要贡献,例如Song等人[71]提出的基于朗格万动力学的分数匹配(SMLD),以及Ho等人提出的带有去噪分数匹配(DDPM)的扩散模型[39]。这些方法在随机微分方程(SDE)的框架内运行,这是Song等人[72]研究中进一步完善的概念。最近的研究进展,如Karras等人[42]和Lee等人[47]所示,展示了使用常微分方程(ODE)采样器进行扩散SDE的有效性,与需要离散化扩散SDE此外,在Flow Matching方法学[51]和Rectified Flow架构[55]领域,SMLD和DDPM都是在概率流ODE框架[72]的不同路径下出现的专门实例。这些模型通常利用线性插值的速度场参数化,这一概念在随机插值框架中得到了更广泛的应用[3],随后的推广扩展到流形设置[13]。SiT模型[61]仔细研究了采样和训练背景下插值方法之间的相互作用,尽管是在较小的分辨率(如512 × 512)的背景下。我们的研究努力将这些见解扩展到更大的范围,专注于1024 × 1024的2D图像和3D视频数据的泛化能力。的传统方法相比,大大降低了采样成本。

3 Method

在本节中,我们首先提供关于状态空间模型的背景信息[31,32,35],并特别关注被称为Mamba的特殊情况[30]。然后,我们强调了Mamba框架内空间连续性的关键问题,并基于这一见解,我们提出之字形Mamba。这种增强旨在通过结合二维数据固有的连续性归纳偏置来提高二维数据建模的效率。此外,我们在Mamba块上设计了一个基本的交叉注意块来实现text-conditioning。随后,我们建议将该方法扩展到3D视频数据,将模型分解为空间和时间维度,从而简化建模过程。最后,我们介绍了用于训练和采样的随机插值的理论方面,这是我们网络架构的基础。

3.1 背景:状态空间模型

状态空间模型(State Space Models, SSM)[31,32,35]已经被证明可以在理论上和经验上处理具有线性缩放w.r.t序列长度的远程依赖[33]。线性状态空间模型的一般形式为:
x ′ ( t ) = A ( t ) x ( t ) + B ( t ) u ( t ) y ( t ) = C ( t ) x ( t ) + D ( t ) u ( t ) \begin{aligned} x^{\prime}(t) & =\mathbf{A}(t) x(t)+\mathbf{B}(t) u(t) \\ y(t) & =\mathbf{C}(t) x(t)+\mathbf{D}(t) u(t) \end{aligned} x(t)y(t)=A(t)x(t)+B(t)u(t)=C(t)x(t)+D(t)u(t)
通过隐式N-D潜在状态序列 x ( t ) ∈ R n x(t) \in \mathbb{R}^n x(t)Rn,将1-D输入序列 u ( t ) ∈ R u(t) \in \mathbb{R} u(t)R映射到1-D输出序列 y ( t ) ∈ R y(t) \in \mathbb{R} y(t)R。具体而言,深度SSM寻求在神经序列建模架构中使用这种简单模型的堆栈,其中每层的参数 A , B , C \mathbf{A}, \mathbf{B}, \mathbf{C} A,B,C D \mathbf{D} D可以通过梯度下降来学习。

最近,Mamba[30]在保持计算效率的同时,通过放宽SSM参数的时不变约束,在很大程度上提高了SSM的灵活性。通过采用高效的并行扫描,Mamba减轻了重复的顺序性质的影响,而融合GPU操作则消除了实现扩展状态的要求。在本文中,我们专注于探索Mamba在扩散模型中的扫描方案,以最大限度地利用多维视觉数据中的归纳偏置。

3.2 扩散主干:之字形Mamba

DiT-Style网络。我们选择使用AdaLN的ViT框架[65],而不是跳跃层的U-ViT结构[9],因为ViT在文献[10,17,64]中已被验证为可扩展结构。考虑到前面提到的几点,它为图4所示的Mamba网络设计提供了信息。这个设计的核心部分是Zigzag形扫描,这将在下面的段落中解释。

在这里插入图片描述

Mamba之字形扫描。先前的研究[81,90]在SSM框架内使用了双向扫描。这种方法已经扩展到包括额外的扫描方向[54,57,91],以考虑二维图像数据的特征。这些方法沿着四个方向展开图像patch,产生四个不同的序列。每个序列随后通过每个SSM一起处理。然而,由于每个方向可能有不同的SSM参数(A、B、C和D),因此增加方向的数量可能会导致内存问题。在这项工作中,我们研究了将Mamba的复杂性分摊到网络的每一层的潜力。

我们的方法围绕着token重新排列的概念,然后将它们馈送到前向扫描块。对于来自层 i i i的给定输入特征 z i z_i zi,重排后的前向扫描块的输出特征 z i + 1 z_{i+1} zi+1可以表示为:
z Ω i = arrange ⁡ ( z i , Ω i ) z ‾ Ω i = scan ⁡ ( z Ω ) z i + 1 = arrange ⁡ ( z ‾ Ω i , Ω ˉ i ) \begin{align} \mathbf{z}_{\Omega_i} & =\operatorname{arrange}\left(\mathbf{z}_i, \Omega_i\right) \tag{1}\\ \overline{\mathbf{z}}_{\Omega_i} & =\operatorname{scan}\left(\mathbf{z}_{\Omega}\right) \tag{2}\\ \mathbf{z}_{i+1} & =\operatorname{arrange}\left(\overline{\mathbf{z}}_{\Omega_i}, \bar{\Omega}_i\right)\tag{3} \end{align} zΩizΩizi+1=arrange(zi,Ωi)=scan(zΩ)=arrange(zΩi,Ωˉi)(1)(2)(3)
Ω i Ω_i Ωi表示第 i i i层的1D排列,将patch token的顺序重新排列 Ω i Ω_i Ωi Ω i Ω_i Ωi Ω ˉ i \bar{Ω}_i Ωˉi表示相反的操作。这确保了 z i z_i zi z i + 1 z_{i+1} zi+1

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值