改进DiT!清华大学提出Inf-DiT:超高分辨图像无限生成

点击下方卡片,关注“CVer”公众号

AI/CV重磅干货,第一时间送达

点击进入—>【Mamba/多模态/扩散】交流群

添加微信:CVer5555,小助手会拉你进群!

扫描下方二维码,加入CVer学术星球!可以获得最新顶会/顶刊上的论文idea和CV从入门到精通资料,及最前沿应用!发论文/搞科研/涨薪,强烈推荐!

177f59f96c7bbc9a2150077c34c53b71.jpeg

转载自:极市平台 | 作者:科技猛兽

导读

 

综合实验表明,Inf-DiT 在生成超高分辨率图像方面取得了 SOTA 性能。与常用的 UNet 结构相比,Inf-DiT 在生成 4096×4096 图像时可以节省超过5倍显存。 

本文目录

1 Inf-DiT:Diffusion Transformer 任意分辨率上采样
(来自清华大学,唐杰团队)
1 Inf-DiT 论文解读
1.1 超高分辨率图像生成问题的挑战:GPU 显存需求
1.2 单向块注意力机制
1.3 O(N) 显存消耗的推理过程
1.4 Inf-DiT 架构
1.5 全局和局部一致性
1.6 实验结果

太长不看版

扩散模型在图像生成方面表现出了很显著的性能。然而对于生成超高分辨率的图像 (比如 4096 ×4096) 而言,由于其 Memory 也会二次方增加,因此生成的图像的分辨率通常限制在 1024×1024。在这项工作中。作者提出了一种单向块注意力机制,可以在推理过程中自适应地调整显存开销并处理全局依赖关系。在这个模块的基础上,作者使用 DiT 的架构,并逐渐执行上采样,最终开发了一个无限的超分辨率模型 Inf-DiT,能够对各种形状和分辨率的图像进行上采样。综合实验表明,Inf-DiT 在生成超高分辨率图像方面取得了 SOTA 性能。与常用的 UNet 结构相比,Inf-DiT 在生成 4096×4096 图像时可以节省超过5倍显存。

43e65b73c1f8f8867638d57843d59d92.png
图1:基于 SDXL、DALL-E 3 和真实图像,选择出的 Inf-DiT 超高分辨率上采样示例

本文做了哪些具体的工作

  1. 提出了单向块注意力机制 (Unidirectional Block Attention,UniBA) 算法,在推理过程中将最小显存消耗从 降低到 , 其中 表示边长。该机制还能够通过调整并行生成的块数量、在显存和时间开销之间进行权衡来适应各种显存限制。

  2. 基于这些方法,训练了一个图像上采样扩散模型 Inf-DiT,这是一个 700M 的模型,能够对不同分辨率的和形状图像进行上采样。Inf-DiT 在机器 (HPDV2 和 DIV2K 数据集) 和人工评估中都实现了最先进的性能。

  3. 设计了多种技术来进一步增强局部和全局一致性,并为灵活的文本控制提供 Zero-Shot 的能力。

1 Inf-DiT:Diffusion Transformer 任意分辨率上采样

论文名称:Inf-DiT: Upsampling Any-Resolution Image with Memory-Efficient Diffusion Transformer (Arxiv 2024.03)

论文地址:

https://arxiv.org/pdf/2405.04312

项目地址:

https://github.com/THUDM/Inf-DiT

1.1 超高分辨率图像生成问题的挑战:GPU 显存需求

近年来,扩散模型取得了快速发展,显着推动了图像生成和编辑领域的发展。尽管取得了进步,但仍然存在一个关键的限制:现有图像扩散模型生成的图像的分辨率通常被限制在 1024×1024 像素或更低,这对生成超高分辨率图像提出了重大挑战,这在包括复杂的设计项目、广告和海报和墙壁纸的创建等各种实际应用中是必不可少的。

生成高分辨率的常用方法是 Cascaded Generation,它首先生成低分辨率图像,然后应用多个上采样模型逐步提高图像的分辨率。这种方法将高分辨率图像的生成分解为多个子任务。基于前一阶段产生的结果,后期的模型只需要执行局部的生成。在级联结构的基础上,DALL-E2[1]和 Imagen[2]都可以有效地生成 1024×1024 分辨率的图像。

上采样到更高分辨率的图像的最大挑战是关于 GPU 显存需求。例如,如果使用广泛采用的 U-Net 架构,例如 SDXL[3]进行图像推理 (见下图2),可以观察到显存消耗随着分辨率的增加而急剧增加。具体来说,如果生成 4096×4096 分辨率的图像,其包含超过 16 亿个像素需要超过 80GB 的显存,超过了标准 RTX 4090 或 A100 显卡的容量。此外,用于高分辨率图像生成的训练模型的过程加剧了这些需求,因为它需要额外的显存来存储梯度、优化器状态等。LDM[4]通过利用变分自动编码器 (Variational Autoencoder,VAE) 压缩图像并在更小的 Latent Space 中生成图像来减少显存消耗。然而,过高的压缩比会大大降低生成的质量,对显存消耗的减少造成了严重的限制。

d3dc807c647cdd5d43827c377a6a3bb7.png
图2:本文模型和 SDXL 架构之间不同分辨率的推理期间显存使用的比较

1.2 单向块注意力机制

作者观察到生成超高分辨率图像的关键障碍是显存限制。随着图像的分辨率的增加,网络中相应的 hidden states 的大小呈二次方的复杂度扩展。例如,1层中形状为 2048×2048×1280 的单个 hidden state 需要 20GB 的显存,这使得很难生成非常大的图像。如何避免将整个图像的 hidden state 存储在内存中成为关键的问题。

作者的主要想法是将图像 划分为 Blocks , 其中36a2836a8883987c4d9aba78293e8a17.png是块大小, 。当图像被送入网络时, Block 的大小和分辨率可能会改变, 但 Block之间的布局和相对位置关系保持不变。如果有一种方法可以应用顺序批量生成 Blocks,其中每个 Batch 同时生成 Blocks 的子集,则只需要同时在内存中保留少量 Blocks 的隐藏状态,就可以生成超高分辨率图像。

本文的方法单向块注意力 (Unidirectional Block Attention, UniBA) 如下图3所示。对于每个层,每个 Block 直接依赖于3个一阶相邻的 Block:顶部的 Block、左侧和左上角的 Block。例如,如果采用 Diffusion Transformer (DiT) 架构,Block 之间的依赖关系是注意力操作,每个 Block 的 Query 向量与4个 Block 的 Key,Value 向量交互:位于其左上角和自身的3个 Blocks,如图3所示。

3233262f0fcd16e93dec07e847699e6b.png
图3:左侧:单向块注意力。在我们的实现中,每个 Block 直接取决于每一层的3个 Blocks:左上角的块、左侧和顶部的 Block;右侧:Inf-DiT 的推理过程。Inf-DiT 根据内存大小每次生成 n×n 个 Block。在这个过程中,只有后续块所依赖的块的 KV-cache 存储在内存中

Transformer 中的 UniBA 过程可以表述为:

31b852400f8158ec985ad38766308b48.png

其中,e89622008f9667e108660dab5a6f6838.png 是第13000c64a06e586e38bc606ced891722.png层, 第78dc243c6d6ae23f09af807afd8ba3fc.png行, 第4dd6dcff119d4e5dd85ee02e7255d33a.png列的 hidden state,d1312a24308576f64253886f104930f3.png是 Block-level 的相对位置编码。

1.3 O(N) 显存消耗的推理过程

尽管本文的方法可以按顺序生成每个 Block,但它与自回归的生成模型不同。在自回归的生成模型中,下一个 Block 取决于前面 Blocks 的最终输出。本文方法可以并行生成任意数量的块。基于这一特性,作者实现了一个简单但有效的推理过程。如图3所示,一次性生成 n×nn×n 个块,从左上角到右下角。在生成一组块后,丢弃不再使用的隐藏状态,并将新生成的 KV-cache 附加到显存中。

可以很容易地证明, 在此过程中保留在显存中的 Block KV-cache 的数量总是 。假设模型在生成单个 Block 时所需的空间为142d5f10e2a1ca3017c31656343d7aa5.png, 一个 Block 的 KV-cache 的空间为50d1c179e7f8d57925ae9e7bc981f78a.png, 其他基本空间消耗 (例如存储原始输入图像) 为C, 则推理过程的最大空间使用为 。当efab19ae7a36d3e5cb6415ad155e99d4.png远小于9c497b09e81496ba01f0c76a5c48b2a4.png时, 内存消耗与56c46469325f5b1ca4eb3afaba5068fd.png成正比。

在实际应用中, 尽管对于不同的57d21518a06f097d769cf5235dbe5794.png值, 生成图片的总 FLOPs 是恒定的, 但是受算子初始化时间与显存分配时间的影响, 当ef4a0c00a127ff5a78fc04e0eb3e11c0.png增加时, 生成时间减少。因此, 最好选择内存限制允许的最大f713af5c2a19208d7947e9e24c400912.png

1.4 Inf-DiT 架构

如下图4所示是 Inf-DiT 架构,它基于 DiT[5]。与基于卷积的结构 (如 U-Net[6]) 相比,DiT 仅利用注意力作为 Patch 之间的交互机制,可以方便地实现 UniBA。为了适应 UniBA,提高上采样的性能,作者做了如下几个修改和优化。

c36adfde436e98afd086eee527e2f155.png
图4:Inf-DiT 架构

模型输入

Inf-DiT首先将输入图像划分为多个不重叠的 Blocks,进一步划分为 Patches。与 DiT 不同,考虑到颜色偏移和细节损失等压缩损失,Inf-DiT 的修补是在 RGB 像素空间中进行的,而不是在 Latent Space。在超分辨率72e1a5339ee9e6ba54b4aef4c24ec444.png次的情况下,Inf-DiT 首先将低分辨率 RGB 图像条件上采样1709e38a58d52d767871f58741c638da.png倍,然后将其与扩散的噪声输入在特征维数上 Concat 起来,然后将其输入到模型中。

位置编码

最近 LLM 的结果表明,与绝对位置编码相比,相对位置编码在捕获词位置相关性方面更有效。作者参考了 Rotary 位置编码 (RoPE)[7]的设计,它在长上下文生成中表现良好,并将其适配到二维形式的图像生成中。具体来说,作者将隐藏状态的通道分成两半,一个用于编码207f1a70eaf2ee8220b1f78bb955c0ac.png坐标,另一个用于编码43075a9c3607d5126f9460a0d1946091.png坐标,分别使用 RoPE。

作者创建了一个足够大的 Rotary 位置编码表。为了确保训练过程中模型可以看到位置编码表的所有部分, 作者使用随机起点: 对于每个训练图像, 为图像的左上角随机分配一个位置 , 而不是默认的 。

此外, 考虑到同一个 Block 内和不同 Block 之间的交互差异, 作者还引入了 Block-level 的相对位置编码45cc908962c0b1b39bbfe9f9a37ca7c2.png, 它根据注意前的相对位置分配不同的 learnable embedding。

1.5 全局和局部一致性

使用 CLIP Image Embedding 针对全局一致性

低分辨率 (LR) 图像中的全局语义信息,如艺术风格和物体材料,在上采样过程中起着至关重要的作用。然而,与文生图像模型相比,上采样模型还有一个额外的任务:理解和分析 LR 图像的语义信息,大大增加了模型的负担。在没有文本数据进行训练时尤其具有挑战性,因为高分辨率图像很少具有高质量的配对文本,这使得模型的这些方面变得困难。

受 DALL-E2[1]的启发,作者利用预训练的 CLIP[8]中的图像编码器从低分辨率图像中提取 Image Embedding17c89f7331cf60afb7812332b6ca9e28.png ,称之为语义输入。由于 CLIP 是在互联网上海量的图像-文本对上训练的,其图像编码器可以有效地从低分辨率图像中提取全局信息。作者将全局语义嵌入添加到 Diffusion Transformer 的 time Embedding 中,并将其输入到每一层,使模型能够直接从高级语义信息中学习。

全局语义嵌入的另一个有趣优势是,使用 CLIP 中的对齐图像-文本 Latent Space,即使本文模型没有在任何图像-文本对上进行训练, 也可以使用文本来指导生成。给定一个正提示a86d3f2cc82405f9ed4c19d0b91272fb.png 和一个负提示dfbed0ae0113536c1c5bad78b5cef8c8.png, 可以更新图像嵌入:

786be216327940233530fb97c25dcaf9.png

其中,07f1dcbe99438b4f8759d69037bc8bf8.png可以控制指导的强度。在推理过程中, 可以简单地使用a39226555222e718ef47daff1e8a15a8.png 代替4cec750c5678d70a441b06e521eccacd.png 作为全局语义嵌入来进行控制。例如, 为了获得更清晰的结果集, "clear" 和 " blur" 有时会有所帮助。

使用 Nearby LR Cross Attention 针对局部一致性

尽管将 LR 图像与噪声输入 Concat 起来已经为模型学习 LR 和 HR 图像之间的局部对应关系提供了良好的归纳偏差,但仍然可能存在连续性的问题。原因是,对于给定的 LR Block,有几种上采样的可能性,这需要与附近的几个 LR Block 一起分析以选择一种解决方案。假设上采样仅基于其左侧的 LR Block 执行,它可能会选择一个与右侧和下方 LR Block 冲突的 HR 生成解决方案。然后,当将 LR Block 上采样到右侧时,如果模型认为符合其对应的 LR Block 比与左侧的 Block 连续更重要,则会生成一个与先前块不连续的 HR Block。一个简单的解决方案是将整个 LR 图像输入到每个 Block,但当 LR 图像的分辨率也很大时,它的成本太高。

为了解决这个问题,作者引入了 Nearby LR Cross-Attention。在第一层中,每个 Block 对周围的 3×3 LR Block 进行 Cross-Attention,以捕获附近的 LR 信息。实验结果表明,这种方法显着减少了生成不连续图像的概率。值得注意的是,这个操作不会改变我们的推理过程,因为在生成之前知道整个 LR 图像。

1.6 实验结果

训练细节

本文的数据集包括 LAION-5B[9]的一个子集,分辨率高于 1024×1024,美学得分高于 5 的 100000 来自互联网的分辨率墙纸。在训练过程中,作者使用 512×512 分辨率的固定大小的 Image crop。由于上采样只能使用局部信息进行,因此在推理过程中可以直接用于更高的分辨率,这对于大多数生成模型来说并不容易。

数据准备

由于扩散模型生成的图像通常包含残余噪声和各种细节不准确,因此增强上采样模型的鲁棒性以解决这些问题变得至关重要。作者采用类似于 Real-ESRGAN[10]的方法对训练数据中的低分辨率输入图像执行各种退化。

在处理分辨率高于 512 的图像时,有两种替代方法:一种是直接执行随机裁剪,另一种是在执行随机裁剪之前将较短的边调整为 512。虽然直接裁剪方法在高分辨率图像中保留了高频特征,但调整大小后裁剪方法避免了频繁裁剪单个颜色背景的区域,不利于模型的收敛。因此在实践中,作者从这两种处理方法中随机选择裁剪训练图像。

作者将 Block Size 设置为 128,Patch Size 设置为 4,即每张图片被分成 4×4 Blocks,每个 Block 被分成 32×32 Patches。作者使用 EDM[11]框架训练,并将上采样设置为4倍。由于上采样任务更关注图像的高频细节,我们将训练噪声分布的均值和标准差调整为 -1.0 和 1.4。为了解决训练期间的溢出问题,作者采用了 BF16 格式。采用的 CLIP 模型是在 Datacomp 数据集[12]上预训练的 ViT-L/16。由于 CLIP 只能处理分辨率为 224×224 的图像,作者首先将 LR 图像的大小调整为 224×224,然后将它们输入到 CLIP 中。

机器评测

作者对 Inf-DiT 与超高分辨率图像生成任务的最新方法进行了定量比较,Baseline 包含两类高分辨率生成方法:

1) 直接高分辨率图像生成,包括 SDXL、MultiDiffusion[13]、ScaleCrafter[14]。

2) 基于超分辨率技术的高分辨率图像生成,包括 BSRGAN[15]、DemoFusion[16]。

使用 FID[17]来评估超高分辨率生成的质量。为了进一步验证我们模型的超分辨率能力,作者还在经典的超分辨率任务上将其与著名的超分辨率模型进行了基准测试。

超高分辨率生成结果

作者使用 HPDV2 的测试集进行评估。它包含 3200 个 Prompt,分为4类:"Animation", "Concept-art", "Painting", 和 "Photo"。这允许对各种域和样式的模型生成能力进行全面的评估。作者在 2048 和 4096 两个分辨率上面进行测试。对于基于超分辨率的模型,作者首先使用 SDXL 生成 1024×1024 分辨率的图像并在没有文本的情况下对其进行上采样。作者使用 BSRGAN 的 2 倍和 4 倍版本分别生成 2048×2048 和 4096×4096 分辨率的图片。虽然 Inf-DiT 是在上采样 4× 的设置下训练的,但作者发现它可以在较低的上采样倍数下很好地泛化。对于 2048×2048 分辨率,作者直接将 LR 图像的大小从 1024×1024 调整到 2048×2048,并将其与噪声输入拼接起来。

如下图5所示的实验结果显示,本文模型在平均得分上超过了所有竞争对手。这展示了本文模型生成高分辨率细节和全局信息的能力。唯一的例外是 4096×4096 分辨率上的 FID 指标,略微落后于 BSRGAN。本文模型可以应用于所有生成模型,不仅仅是 SDXL。

97595b18a3ca3c1b5308bc5f3bc24d7d.png
图5:HPDV2 数据集上超高分辨率生成方法的定量比较结果
96a75e7f5c9e70f3d5d4684785a73bde.png
图6:2048×2048 分辨率下对不同方法的定性比较
c7f340aeba6fb9ed6457a3ee83f42e49.png
图7:不同方法在 4096×4096 分辨率下的定性比较

超分辨率实验结果

除了生成高分辨率图像的能力外,Inf-DiT 也可以用作经典的超分辨率模型。作者对 DIV2k 验证集进行评估,该数据集包含不同场景下多个真实世界的高分辨率图像。作者将图像退化固定为 4× 下采样的双三次插值。在与固定分辨率模型 LDM 和 StableSR 进行比较之前,作者从高分辨率图像中心裁剪特定的小块作为 ground truth。在整个过程中,作者使用感知 (FID, FIDcrop) 和保真度 (PSNR, SSIM) 指标来确保详细和全面的评估。

实验结果如图8所示,本文模型在所有指标上实现了最先进的性能。这意味着,作为超分辨率模型,Inf-DiT 不仅擅长在任意尺度上执行超分辨率,还擅长在恢复与原始图像非常相似的结果的同时最佳地保留全局和局部信息。

2664f52fcdc89cf8c65396084f764a3c.png
图8:DIV2K 数据集与最先进的超分辨率方法的比较结果

人类评测结果

为了从人类的角度更准确地反映其生成质量,作者进行了人工评估。作者比较了4个模型,对每个模型随机选择十个比较集,每个比较集包含来自四个模型的输出,最终总共有 40 个数据。为了确保公平,作者在每个比较集中随机化模型输出序列。人类评估者被要求根据3个标准评估模型:细节真实性、全局连贯性和与原始低分辨率输入的一致性。每个评估者平均接收 20 组图像。在每个集合中,评估者需要根据3个标准将四个模型生成的图像从最高排名排名最低。

最终收集了 3,600 次比较。如图9所示,本文模型在所有3个标准上都优于其他3种方法。特别值得注意的是,其他3个模型在3个评估标准中至少有一个的排名相对较低,而 Inf-DiT 在所有3个标准上都取得了最高分:细节真实性、全局连贯性和与低分辨率输入的一致性。这表明本文模型是唯一能够同时在高分辨率生成和超分辨率任务中表现出色的模型。

860ff0ffa8e823f032b8f23b8e329eb9.png
图9:人类评估结果

迭代上采样

由于本文的模型可以对任意分辨率的图像进行上采样,因此测试模型是否可以迭代地对自身生成的图像进行上采样是很自然的想法。在这项研究中,作者尝试通过3次迭代上采样从 32×32 分辨率图像上采样 64 倍之后生成 2048×2048 分辨率图像。图 10 展示了两种样本。在第1个样本中,模型在上采样3个阶段后成功地生成了高分辨率图像。它在不同的分辨率上采样中生成不同频率的细节:人脸的轮廓、眼球的形状和个人睫毛。然而,模型很难纠正早期阶段产生的不准确,从而导致错误的积累。在第2个样本中,作者演示了这个问题的一个示例。我们将此问题留给未来的工作。

5306d3b7f6cbe94a252ee76c5b53fc39.png
图10:迭代上采样结果。上:Inf-DiT 可以多次上采样自己生成的图像,并在相应分辨率下生成不同频率的细节;下:在低分辨率 128×128 时未能准确生成之后,后续很难纠正错误

参考

  1. ^abHigh-resolution image synthesis with latent diffusion models

  2. ^Photorealistic textto-image diffusion models with deep language understanding

  3. ^Sdxl: Improving latent diffusion models for high-resolution image synthesis

  4. ^High-resolution image synthesis with latent diffusion models

  5. ^Scalable Diffusion Models with Transformers

  6. ^U-Net: Convolutional Networks for Biomedical Image Segmentation

  7. ^RoFormer: Enhanced Transformer with Rotary Position Embedding

  8. ^Learning Transferable Visual Models From Natural Language Supervision

  9. ^LAION-5B: An open large-scale dataset for training next generation image-text models

  10. ^Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data

  11. ^Elucidating the Design Space of Diffusion-Based Generative Models

  12. ^https://doi.org/10.5281/zenodo.5143773

  13. ^Multidiffusion: Fusing diffusion paths for controlled image generation

  14. ^Scalecrafter: Tuning-free higher-resolution visual generation with diffusion models

  15. ^Designing a practical degradation model for deep blind image super-resolution

  16. ^Demofusion: Democratising high-resolution image generation with no $$$

  17. ^GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium

何恺明在MIT授课的课件PPT下载

在CVer公众号后台回复:何恺明,即可下载本课程的所有566页课件PPT!赶紧学起来!

CVPR 2024 论文和代码下载

在CVer公众号后台回复:CVPR2024,即可下载CVPR 2024论文和代码开源的论文合集

Mamba、多模态和扩散模型交流群成立

 
 
扫描下方二维码,或者添加微信:CVer5555,即可添加CVer小助手微信,便可申请加入CVer-Mamba、多模态学习或者扩散模型微信交流群。另外其他垂直方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch、TensorFlow和Transformer、NeRF、3DGS、Mamba等。
一定要备注:研究方向+地点+学校/公司+昵称(如Mamba、多模态学习或者扩散模型+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群

 
 
▲扫码或加微信号: CVer5555,进交流群
CVer计算机视觉(知识星球)来了!想要了解最新最快最好的CV/DL/AI论文速递、优质实战项目、AI行业前沿、从入门到精通学习教程等资料,欢迎扫描下方二维码,加入CVer计算机视觉(知识星球),已汇集近万人!

▲扫码加入星球学习
 
 
▲点击上方卡片,关注CVer公众号
整理不易,请赞和在看
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值