超分辨图像无限生成!清华Inf-DiT:任意分辨率上采样

作者 | 科技猛兽  编辑 | 极市平台

点击下方卡片,关注“自动驾驶之心”公众号

戳我-> 领取自动驾驶近15个方向学习路线

>>点击进入→自动驾驶之心扩散模型技术交流群

本文只做学术分享,如有侵权,联系删文

导读

 

综合实验表明,Inf-DiT 在生成超高分辨率图像方面取得了 SOTA 性能。与常用的 UNet 结构相比,Inf-DiT 在生成 4096×4096 图像时可以节省超过5倍显存。

本文目录

1 Inf-DiT:Diffusion Transformer 任意分辨率上采样
(来自清华大学,唐杰团队)
1 Inf-DiT 论文解读
1.1 超高分辨率图像生成问题的挑战:GPU 显存需求
1.2 单向块注意力机制
1.3 O(N) 显存消耗的推理过程
1.4 Inf-DiT 架构
1.5 全局和局部一致性
1.6 实验结果

太长不看版

扩散模型在图像生成方面表现出了很显著的性能。然而对于生成超高分辨率的图像 (比如 4096 ×4096) 而言,由于其 Memory 也会二次方增加,因此生成的图像的分辨率通常限制在 1024×1024。在这项工作中。作者提出了一种单向块注意力机制,可以在推理过程中自适应地调整显存开销并处理全局依赖关系。在这个模块的基础上,作者使用 DiT 的架构,并逐渐执行上采样,最终开发了一个无限的超分辨率模型 Inf-DiT,能够对各种形状和分辨率的图像进行上采样。综合实验表明,Inf-DiT 在生成超高分辨率图像方面取得了 SOTA 性能。与常用的 UNet 结构相比,Inf-DiT 在生成 4096×4096 图像时可以节省超过5倍显存。

ee78193b3d46ffc94c6be434b343c43b.png
图1:基于 SDXL、DALL-E 3 和真实图像,选择出的 Inf-DiT 超高分辨率上采样示例

本文做了哪些具体的工作

  1. 提出了单向块注意力机制 (Unidirectional Block Attention,UniBA) 算法,在推理过程中将最小显存消耗从  降低到 , 其中  表示边长。该机制还能够通过调整并行生成的块数量、在显存和时间开销之间进行权衡来适应各种显存限制。

  2. 基于这些方法,训练了一个图像上采样扩散模型 Inf-DiT,这是一个 700M 的模型,能够对不同分辨率的和形状图像进行上采样。Inf-DiT 在机器 (HPDV2 和 DIV2K 数据集) 和人工评估中都实现了最先进的性能。

  3. 设计了多种技术来进一步增强局部和全局一致性,并为灵活的文本控制提供 Zero-Shot 的能力。

1 Inf-DiT:Diffusion Transformer 任意分辨率上采样

论文名称:Inf-DiT: Upsampling Any-Resolution Image with Memory-Efficient Diffusion Transformer (Arxiv 2024.03)

论文地址:

https://arxiv.org/pdf/2405.04312

项目地址:

https://github.com/THUDM/Inf-DiT

1.1 超高分辨率图像生成问题的挑战:GPU 显存需求

近年来,扩散模型取得了快速发展,显着推动了图像生成和编辑领域的发展。尽管取得了进步,但仍然存在一个关键的限制:现有图像扩散模型生成的图像的分辨率通常被限制在 1024×1024 像素或更低,这对生成超高分辨率图像提出了重大挑战,这在包括复杂的设计项目、广告和海报和墙壁纸的创建等各种实际应用中是必不可少的。

生成高分辨率的常用方法是 Cascaded Generation,它首先生成低分辨率图像,然后应用多个上采样模型逐步提高图像的分辨率。这种方法将高分辨率图像的生成分解为多个子任务。基于前一阶段产生的结果,后期的模型只需要执行局部的生成。在级联结构的基础上,DALL-E2[1]和 Imagen[2]都可以有效地生成 1024×1024 分辨率的图像。

上采样到更高分辨率的图像的最大挑战是关于 GPU 显存需求。例如,如果使用广泛采用的 U-Net 架构,例如 SDXL[3]进行图像推理 (见下图2),可以观察到显存消耗随着分辨率的增加而急剧增加。具体来说,如果生成 4096×4096 分辨率的图像,其包含超过 16 亿个像素需要超过 80GB 的显存,超过了标准 RTX 4090 或 A100 显卡的容量。此外,用于高分辨率图像生成的训练模型的过程加剧了这些需求,因为它需要额外的显存来存储梯度、优化器状态等。LDM[4]通过利用变分自动编码器 (Variational Autoencoder,VAE) 压缩图像并在更小的 Latent Space 中生成图像来减少显存消耗。然而,过高的压缩比会大大降低生成的质量,对显存消耗的减少造成了严重的限制。

0ef5982de494b372982623308413f0b0.png
图2:本文模型和 SDXL 架构之间不同分辨率的推理期间显存使用的比较

1.2 单向块注意力机制

作者观察到生成超高分辨率图像的关键障碍是显存限制。随着图像的分辨率的增加,网络中相应的 hidden states 的大小呈二次方的复杂度扩展。例如,1层中形状为 2048×2048×1280 的单个 hidden state 需要 20GB 的显存,这使得很难生成非常大的图像。如何避免将整个图像的 hidden state 存储在内存中成为关键的问题。

作者的主要想法是将图像  划分为 Blocks , 其中09b84026a46c1680768312ec0db28234.png是块大小,  。当图像被送入网络时, Block 的大小和分辨率可能会改变, 但 Block之间的布局和相对位置关系保持不变。如果有一种方法可以应用顺序批量生成 Blocks,其中每个 Batch 同时生成 Blocks 的子集,则只需要同时在内存中保留少量 Blocks 的隐藏状态,就可以生成超高分辨率图像。

本文的方法单向块注意力 (Unidirectional Block Attention, UniBA) 如下图3所示。对于每个层,每个 Block 直接依赖于3个一阶相邻的 Block:顶部的 Block、左侧和左上角的 Block。例如,如果采用 Diffusion Transformer (DiT) 架构,Block 之间的依赖关系是注意力操作,每个 Block 的 Query 向量与4个 Block 的 Key,Value 向量交互:位于其左上角和自身的3个 Blocks,如图3所示。

e5b0b562407903bc7eebc795f137581b.png
图3:左侧:单向块注意力。在我们的实现中,每个 Block 直接取决于每一层的3个 Blocks:左上角的块、左侧和顶部的 Block;右侧:Inf-DiT 的推理过程。Inf-DiT 根据内存大小每次生成 n×n 个 Block。在这个过程中,只有后续块所依赖的块的 KV-cache 存储在内存中

Transformer 中的 UniBA 过程可以表述为:

e476bc705a4a109454d49f8fec5840fe.png

其中,3e13e1aea3da55951d1b83b05ec89f98.png 是第9e268bfb9cb71dab332adf70af36adf2.png层, 第78b6ff0ca1bf8ba0d7d9165047c29ed4.png行, 第515fe90dfa71218ba6b1a22c0dfd463d.png列的 hidden state,44fcaf089c9556745bc294a64e969239.png是 Block-level 的相对位置编码。

1.3 O(N) 显存消耗的推理过程

尽管本文的方法可以按顺序生成每个 Block,但它与自回归的生成模型不同。在自回归的生成模型中,下一个 Block 取决于前面 Blocks 的最终输出。本文方法可以并行生成任意数量的块。基于这一特性,作者实现了一个简单但有效的推理过程。如图3所示,一次性生成 n×nn×n 个块,从左上角到右下角。在生成一组块后,丢弃不再使用的隐藏状态,并将新生成的 KV-cache 附加到显存中。

可以很容易地证明, 在此过程中保留在显存中的 Block KV-cache 的数量总是  。假设模型在生成单个 Block 时所需的空间为a5d2b1b730cf86fa301271e1ca424bc0.png, 一个 Block 的 KV-cache 的空间为25d2d5cf83ea8ab46688c9dab57300b2.png, 其他基本空间消耗 (例如存储原始输入图像) 为C, 则推理过程的最大空间使用为  。当7e3851af5e15165e762581666a462c14.png远小于b00bb993457b1efc88a92dfe75c93fed.png时, 内存消耗与5e22322bad937a2a022d9f66438a4904.png成正比。

在实际应用中, 尽管对于不同的85df2f5b87e59795ca34bf42122e2aa4.png值, 生成图片的总 FLOPs 是恒定的, 但是受算子初始化时间与显存分配时间的影响, 当15cfa7128e0576b30f2f5a36d3d45ae5.png增加时, 生成时间减少。因此, 最好选择内存限制允许的最大4cb5ede6a44dce548383d2fd930fc571.png

1.4 Inf-DiT 架构

如下图4所示是 Inf-DiT 架构,它基于 DiT[5]。与基于卷积的结构 (如 U-Net[6]) 相比,DiT 仅利用注意力作为 Patch 之间的交互机制,可以方便地实现 UniBA。为了适应 UniBA,提高上采样的性能,作者做了如下几个修改和优化。

5f9191ecd89d05df305ea048316748c1.png
图4:Inf-DiT 架构

模型输入

Inf-DiT首先将输入图像划分为多个不重叠的 Blocks,进一步划分为 Patches。与 DiT 不同,考虑到颜色偏移和细节损失等压缩损失,Inf-DiT 的修补是在 RGB 像素空间中进行的,而不是在 Latent Space。在超分辨率a75b25a9566cce87c0ac21cf79dc5c2c.png次的情况下,Inf-DiT 首先将低分辨率 RGB 图像条件上采样8eb66936cc80b041ef34ded3d8710f51.png倍,然后将其与扩散的噪声输入在特征维数上 Concat 起来,然后将其输入到模型中。

位置编码

最近 LLM 的结果表明,与绝对位置编码相比,相对位置编码在捕获词位置相关性方面更有效。作者参考了 Rotary 位置编码 (RoPE)[7]的设计,它在长上下文生成中表现良好,并将其适配到二维形式的图像生成中。具体来说,作者将隐藏状态的通道分成两半,一个用于编码4a2621279e4150b64c6ece2d4dfe7135.png坐标,另一个用于编码264b7b4543cd60a23b7ac676a049bd06.png坐标,分别使用 RoPE。

作者创建了一个足够大的 Rotary 位置编码表。为了确保训练过程中模型可以看到位置编码表的所有部分, 作者使用随机起点: 对于每个训练图像, 为图像的左上角随机分配一个位置 , 而不是默认的  。

此外, 考虑到同一个 Block 内和不同 Block 之间的交互差异, 作者还引入了 Block-level 的相对位置编码b5e36ed22c48d55ff344b6cf3880cf8a.png, 它根据注意前的相对位置分配不同的 learnable embedding。

1.5 全局和局部一致性

使用 CLIP Image Embedding 针对全局一致性

低分辨率 (LR) 图像中的全局语义信息,如艺术风格和物体材料,在上采样过程中起着至关重要的作用。然而,与文生图像模型相比,上采样模型还有一个额外的任务:理解和分析 LR 图像的语义信息,大大增加了模型的负担。在没有文本数据进行训练时尤其具有挑战性,因为高分辨率图像很少具有高质量的配对文本,这使得模型的这些方面变得困难。

受 DALL-E2[1]的启发,作者利用预训练的 CLIP[8]中的图像编码器从低分辨率图像中提取 Image Embeddinge67ded5d6ec491b2f703cbedd6beabc1.png ,称之为语义输入。由于 CLIP 是在互联网上海量的图像-文本对上训练的,其图像编码器可以有效地从低分辨率图像中提取全局信息。作者将全局语义嵌入添加到 Diffusion Transformer 的 time Embedding 中,并将其输入到每一层,使模型能够直接从高级语义信息中学习。

全局语义嵌入的另一个有趣优势是,使用 CLIP 中的对齐图像-文本 Latent Space,即使本文模型没有在任何图像-文本对上进行训练, 也可以使用文本来指导生成。给定一个正提示bddacdcf30526ae6492f1905e9c35c9a.png 和一个负提示e9467e5c38d64b8010a52bbda49ca3f0.png, 可以更新图像嵌入:

e89fce30849503d9a9ccb7d35dc56e03.png

其中,254a2d4a3f45e4c390003522bf6e2da3.png可以控制指导的强度。在推理过程中, 可以简单地使用9032fbe3585c387f859bbaefacf3c1f6.png 代替2b49b8292856ec56e091bc94420f5f32.png 作为全局语义嵌入来进行控制。例如, 为了获得更清晰的结果集,  "clear" 和  " blur" 有时会有所帮助。

使用 Nearby LR Cross Attention 针对局部一致性

尽管将 LR 图像与噪声输入 Concat 起来已经为模型学习 LR 和 HR 图像之间的局部对应关系提供了良好的归纳偏差,但仍然可能存在连续性的问题。原因是,对于给定的 LR Block,有几种上采样的可能性,这需要与附近的几个 LR Block 一起分析以选择一种解决方案。假设上采样仅基于其左侧的 LR Block 执行,它可能会选择一个与右侧和下方 LR Block 冲突的 HR 生成解决方案。然后,当将 LR Block 上采样到右侧时,如果模型认为符合其对应的 LR Block 比与左侧的 Block 连续更重要,则会生成一个与先前块不连续的 HR Block。一个简单的解决方案是将整个 LR 图像输入到每个 Block,但当 LR 图像的分辨率也很大时,它的成本太高。

为了解决这个问题,作者引入了 Nearby LR Cross-Attention。在第一层中,每个 Block 对周围的 3×3 LR Block 进行 Cross-Attention,以捕获附近的 LR 信息。实验结果表明,这种方法显着减少了生成不连续图像的概率。值得注意的是,这个操作不会改变我们的推理过程,因为在生成之前知道整个 LR 图像。

1.6 实验结果

训练细节

本文的数据集包括 LAION-5B[9]的一个子集,分辨率高于 1024×1024,美学得分高于 5 的 100000 来自互联网的分辨率墙纸。在训练过程中,作者使用 512×512 分辨率的固定大小的 Image crop。由于上采样只能使用局部信息进行,因此在推理过程中可以直接用于更高的分辨率,这对于大多数生成模型来说并不容易。

数据准备

由于扩散模型生成的图像通常包含残余噪声和各种细节不准确,因此增强上采样模型的鲁棒性以解决这些问题变得至关重要。作者采用类似于 Real-ESRGAN[10]的方法对训练数据中的低分辨率输入图像执行各种退化。

在处理分辨率高于 512 的图像时,有两种替代方法:一种是直接执行随机裁剪,另一种是在执行随机裁剪之前将较短的边调整为 512。虽然直接裁剪方法在高分辨率图像中保留了高频特征,但调整大小后裁剪方法避免了频繁裁剪单个颜色背景的区域,不利于模型的收敛。因此在实践中,作者从这两种处理方法中随机选择裁剪训练图像。

作者将 Block Size 设置为 128,Patch Size 设置为 4,即每张图片被分成 4×4 Blocks,每个 Block 被分成 32×32 Patches。作者使用 EDM[11]框架训练,并将上采样设置为4倍。由于上采样任务更关注图像的高频细节,我们将训练噪声分布的均值和标准差调整为 -1.0 和 1.4。为了解决训练期间的溢出问题,作者采用了 BF16 格式。采用的 CLIP 模型是在 Datacomp 数据集[12]上预训练的 ViT-L/16。由于 CLIP 只能处理分辨率为 224×224 的图像,作者首先将 LR 图像的大小调整为 224×224,然后将它们输入到 CLIP 中。

机器评测

作者对 Inf-DiT 与超高分辨率图像生成任务的最新方法进行了定量比较,Baseline 包含两类高分辨率生成方法:

1) 直接高分辨率图像生成,包括 SDXL、MultiDiffusion[13]、ScaleCrafter[14]。

2) 基于超分辨率技术的高分辨率图像生成,包括 BSRGAN[15]、DemoFusion[16]。

使用 FID[17]来评估超高分辨率生成的质量。为了进一步验证我们模型的超分辨率能力,作者还在经典的超分辨率任务上将其与著名的超分辨率模型进行了基准测试。

超高分辨率生成结果

作者使用 HPDV2 的测试集进行评估。它包含 3200 个 Prompt,分为4类:"Animation", "Concept-art", "Painting", 和 "Photo"。这允许对各种域和样式的模型生成能力进行全面的评估。作者在 2048 和 4096 两个分辨率上面进行测试。对于基于超分辨率的模型,作者首先使用 SDXL 生成 1024×1024 分辨率的图像并在没有文本的情况下对其进行上采样。作者使用 BSRGAN 的 2 倍和 4 倍版本分别生成 2048×2048 和 4096×4096 分辨率的图片。虽然 Inf-DiT 是在上采样 4× 的设置下训练的,但作者发现它可以在较低的上采样倍数下很好地泛化。对于 2048×2048 分辨率,作者直接将 LR 图像的大小从 1024×1024 调整到 2048×2048,并将其与噪声输入拼接起来。

如下图5所示的实验结果显示,本文模型在平均得分上超过了所有竞争对手。这展示了本文模型生成高分辨率细节和全局信息的能力。唯一的例外是 4096×4096 分辨率上的 FID 指标,略微落后于 BSRGAN。本文模型可以应用于所有生成模型,不仅仅是 SDXL。

65daf175ac85d6f355803a6788cd537f.png
图5:HPDV2 数据集上超高分辨率生成方法的定量比较结果
ab832dcc45eca9528b7c6a3995b3cb73.png
图6:2048×2048 分辨率下对不同方法的定性比较
7ca9f06cdb73d52b893754b164b0df0e.png
图7:不同方法在 4096×4096 分辨率下的定性比较

超分辨率实验结果

除了生成高分辨率图像的能力外,Inf-DiT 也可以用作经典的超分辨率模型。作者对 DIV2k 验证集进行评估,该数据集包含不同场景下多个真实世界的高分辨率图像。作者将图像退化固定为 4× 下采样的双三次插值。在与固定分辨率模型 LDM 和 StableSR 进行比较之前,作者从高分辨率图像中心裁剪特定的小块作为 ground truth。在整个过程中,作者使用感知 (FID, FIDcrop) 和保真度 (PSNR, SSIM) 指标来确保详细和全面的评估。

实验结果如图8所示,本文模型在所有指标上实现了最先进的性能。这意味着,作为超分辨率模型,Inf-DiT 不仅擅长在任意尺度上执行超分辨率,还擅长在恢复与原始图像非常相似的结果的同时最佳地保留全局和局部信息。

9ea4d8ba71319aa3a69a17eabd73de60.png
图8:DIV2K 数据集与最先进的超分辨率方法的比较结果

人类评测结果

为了从人类的角度更准确地反映其生成质量,作者进行了人工评估。作者比较了4个模型,对每个模型随机选择十个比较集,每个比较集包含来自四个模型的输出,最终总共有 40 个数据。为了确保公平,作者在每个比较集中随机化模型输出序列。人类评估者被要求根据3个标准评估模型:细节真实性、全局连贯性和与原始低分辨率输入的一致性。每个评估者平均接收 20 组图像。在每个集合中,评估者需要根据3个标准将四个模型生成的图像从最高排名排名最低。

最终收集了 3,600 次比较。如图9所示,本文模型在所有3个标准上都优于其他3种方法。特别值得注意的是,其他3个模型在3个评估标准中至少有一个的排名相对较低,而 Inf-DiT 在所有3个标准上都取得了最高分:细节真实性、全局连贯性和与低分辨率输入的一致性。这表明本文模型是唯一能够同时在高分辨率生成和超分辨率任务中表现出色的模型。

20b63380700d8c236ae405beb4072897.png
图9:人类评估结果

迭代上采样

由于本文的模型可以对任意分辨率的图像进行上采样,因此测试模型是否可以迭代地对自身生成的图像进行上采样是很自然的想法。在这项研究中,作者尝试通过3次迭代上采样从 32×32 分辨率图像上采样 64 倍之后生成 2048×2048 分辨率图像。图 10 展示了两种样本。在第1个样本中,模型在上采样3个阶段后成功地生成了高分辨率图像。它在不同的分辨率上采样中生成不同频率的细节:人脸的轮廓、眼球的形状和个人睫毛。然而,模型很难纠正早期阶段产生的不准确,从而导致错误的积累。在第2个样本中,作者演示了这个问题的一个示例。我们将此问题留给未来的工作。

2531df058b14f68cc09874503db74cdc.png
图10:迭代上采样结果。上:Inf-DiT 可以多次上采样自己生成的图像,并在相应分辨率下生成不同频率的细节;下:在低分辨率 128×128 时未能准确生成之后,后续很难纠正错误

参考

  1. ^abHigh-resolution image synthesis with latent diffusion models

  2. ^Photorealistic textto-image diffusion models with deep language understanding

  3. ^Sdxl: Improving latent diffusion models for high-resolution image synthesis

  4. ^High-resolution image synthesis with latent diffusion models

  5. ^Scalable Diffusion Models with Transformers

  6. ^U-Net: Convolutional Networks for Biomedical Image Segmentation

  7. ^RoFormer: Enhanced Transformer with Rotary Position Embedding

  8. ^Learning Transferable Visual Models From Natural Language Supervision

  9. ^LAION-5B: An open large-scale dataset for training next generation image-text models

  10. ^Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data

  11. ^Elucidating the Design Space of Diffusion-Based Generative Models

  12. ^https://doi.org/10.5281/zenodo.5143773

  13. ^Multidiffusion: Fusing diffusion paths for controlled image generation

  14. ^Scalecrafter: Tuning-free higher-resolution visual generation with diffusion models

  15. ^Designing a practical degradation model for deep blind image super-resolution

  16. ^Demofusion: Democratising high-resolution image generation with no $$$

  17. ^GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium

投稿作者为『自动驾驶之心知识星球』特邀嘉宾,欢迎加入交流!

① 全网独家视频课程

BEV感知、BEV模型部署、BEV目标跟踪、毫米波雷达视觉融合多传感器标定多传感器融合多模态3D目标检测车道线检测轨迹预测在线高精地图世界模型点云3D目标检测目标跟踪Occupancy、cuda与TensorRT模型部署大模型与自动驾驶Nerf语义分割自动驾驶仿真、传感器部署、决策规划、轨迹预测等多个方向学习视频(扫码即可学习

f8de02b3653c8e56430454bc3ca94ab0.png

网页端官网:www.zdjszx.com

② 国内首个自动驾驶学习社区

国内最大最专业,近3000人的交流社区,已得到大多数自动驾驶公司的认可!涉及30+自动驾驶技术栈学习路线,从0到一带你入门自动驾驶感知2D/3D检测、语义分割、车道线、BEV感知、Occupancy、多传感器融合、多传感器标定、目标跟踪)、自动驾驶定位建图SLAM、高精地图、局部在线地图)、自动驾驶规划控制/轨迹预测等领域技术方案大模型、端到端等,更有行业动态和岗位发布!欢迎扫描下方二维码,加入自动驾驶之心知识星球,这是一个真正有干货的地方,与领域大佬交流入门、学习、工作、跳槽上的各类难题,日常分享论文+代码+视频

5e8bd3149fe0db57d42b766cec12c38c.png

③【自动驾驶之心】技术交流群

自动驾驶之心是首个自动驾驶开发者社区,聚焦感知、定位、融合、规控、标定、端到端、仿真、产品经理、自动驾驶开发、自动标注与数据闭环多个方向,目前近60+技术交流群,欢迎加入!扫码添加汽车人助理微信邀请入群,备注:学校/公司+方向+昵称(快速入群方式)

2b11babaa3cc19f5c1db53f7bee295f1.jpeg

④【自动驾驶之心】全平台矩阵

f113e2af22ba354d95036f204669380c.png

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值