论文笔记:Inf-DiT: Upsampling Any-Resolution Image with Memory-Efficient Diffusion Transformer

论文链接:[2405.04312] Inf-DiT: Upsampling Any-Resolution Image with Memory-Efficient Diffusion Transformer (arxiv.org)

论文代码:THUDM/Inf-DiT: Official implementation of Inf-DiT: Upsampling Any-Resolution Image with Memory-Efficient Diffusion Transformer (github.com)

现有图像扩散模型生成的图像的分辨率通常被限制在 1024 ×1024 像素或更低,在生成超高分辨率图像(例如 4096 ×4096)时内存会二次增加,

上采样到更高分辨率的图像的最大挑战是显着的 GPU 内存需求。另外一个问题是如果要将图像完整的输入模型中,会占用的空间。

因此本文提出了一种单向块注意(UniBA)算法,该算法可以显著降低从O(N2)到O(N)的生成空间复杂度,大大提高了最高的可用分辨率。

Methodology

Unidirectional Block Attention (UniBA)

在UNet、DiT等模型中,块之间的依赖关系是双向的,即在计算时必须同时生成图像中的所有块。为了节省块隐藏状态的内存,我们希望设计一种算法,使其允

许同一图像中的块被分成几批来生成,每批只需要同时生成一部分块,并按批次顺序生成。

主要思想是将图片划分为块,其中B为块的大小。并提出了如下图所示的注意力实现:

左图:单向块注意力中,每个块直接取决于自身层的三个块:左上角的块、左侧和上面的块。

右图:Inf-DiT 的推理过程。Inf-DiT 根据内存大小每次生成 n × n的block。在这个过程中,只有后续块所依赖的块的KV-cache存储在内存中。

Inf-DiT 架构中,块之间的依赖关系是注意力操作。且transformer中单向块注意力可以计算如下:

表示第n层i行j列的块的隐藏状态,为块间相对位置编码。

虽然该方法每一个block的计算依赖的范围变小了,但是由于特征逐层传递,还是可以捕捉到长距离的信息;

在上图中,随着block计算的向前推进,不断有block的hidden states的值被丢弃。即可空间复杂度由原来的变为

Basic Model Architecture

Inf-DiT 的架构使用了与DiT类似的主干,它将Vision Transformer (ViT)应用于扩散模型,与基于卷积的体系结构(如UNet)相比,DiT仅利用注意力作为patch之间的

交互机制,可以方便地实现单向块注意。为了适应单向块注意,提高上采样的性能,我们做了如下几个修改和优化。

Model input

考虑到颜色偏移和细节损失等压缩产生的损失,Inf-DiT 的重建是在 RGB 像素空间中进行的,而不是潜在空间。在超分为f倍时,首先将低分辨率RGB图像上采样f倍,然后将其与扩散的噪声输入在特征维数上连接起来,然后将其输入到模型中。

Position Encoding

参考RoPE旋转位置编码。首先创建一个足够大的位置编码表,使用随机起点:对于每个训练图像,为图像的左上角随机分配一个位置 (x, y),而不是默认的 (0,0)。此外,考虑到同一块内和不同块之间的交互差异,还引入了块级相对位置编码,它根据注意前的相对位置分配不同的可学习嵌入。

Global and Local Consistency

Global Consistency with CLIP Image Embedding

利用预训练的CLIP中的图像编码器从低分辨率图像中提取图像嵌入,称之为语义输入。由于CLIP是在互联网上海量的图像-文本对上训练的,其图像编码器可以有效地从低分辨率图像中提取全局信息。将全局语义嵌入添加到DiT的时间嵌入中,并将其输入到每一层,使模型能够直接从高级语义信息中学习。

使用 CLIP 中的图像-文本潜在空间,即使模型没有在任何图像-文本对上进行训练,也可以使用文本来指导生成的方向。

给定一个正提示和一个负提示,就可以更新图像嵌入:

α用于控制语义的引导强度。在推理过程中,我们可以简单地使用

代替 作为全局语义嵌入来进行控制。

Local Consistency with Nearby LR Cross Attention

模型学习 LR 和 HR 图像之间的局部对应关系时仍然可能存在连续性问题。为了解决这个问题,引入了 Nearby LR Cross Attention。在transformer的第一层中,

每个块对周围的3 × 3 LR块进行交叉注意,以捕获附近的LR信息。实验表明,这种方法显着减少了生成不连续图像的概率。

Experiments

HPDV2数据集下超高分辨率的定量实验:

表现了模型生成高分辨率细节和协调全局信息的能力。虽然在4096X4096下的FID值略小于BSRGAN,但FIDcrop 是高分辨率特征的更有代表性的指标

FIDcrop是从高分辨率图像中随机抽取299 × 299个patch进行FID评估,不会像FID一样忽略了高分辨率的细节,因为FID的原始实现需要在特征提取前将输入图像

下采样到299 × 299的分辨率

下表是在DIV2K数据集下的超分定量实验:

Ablation Study

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值