Packing Input Frame Context in Next-Frame Prediction Models for Video Generation
贡献
- FramePack压缩输入帧,使Transformer的上下文长度成为固定数量,而不管视频输入长度如何。这样能够使用视频扩散处理大量的帧(因为无论多少帧都可以压缩到固定大小),其计算bottleneck类似于图像扩散。这也使得训练视频的批量大小显著增加 (批量大小与图像扩散训练相当)。
- 提出一种抗漂移采样方法,早期设立最后一帧,然后反转时间顺序生成,以避免暴露偏差 (迭代后的错误累积)。
视频生成中的遗忘和漂移
遗忘和漂移是视频生成中下一帧(或下一帧片段)预测模型的两个最关键问题,其中,“遗忘”是指随着模型努力记住早期内容并保持一致的时间依赖性而记忆逐渐消退;“漂移”是指随着时间推移误差累积而导致视觉质量迭代下降(也称为曝光偏差)。
也就是说想记住更久之前的帧的影响,但是随着时间越往后,生成偏差越来越严重。
当试图同时解决遗忘和漂移问题时,会出现一个基本的困境:
- 任何通过增强记忆来减轻遗忘的方法也可能使错误积累/传播更快,从而加剧漂移;
- 任何通过中断错误传播并削弱时间依赖性(例如掩蔽或重新添加噪声的历史)来减少漂移的方法也可能会加重遗忘。
这种本质上的权衡阻碍了下一帧预测模型的可扩展性。
解决这两个问题有很简单的思路:
- 对于遗忘问题,那我就尽量编码尽可能多的帧,但是这样会导致Transformer的复杂度很高,无法计算。
- 对于漂移问题,自回归任务中任何一帧的初始错误都会随着回归过程在后续帧中传播并累积,导致视觉质量下降,就要尽量减少误差的传播和累积。
作者的解决思路
FramePack结构通过根据输入帧的相对重要性(根据帧序列的远近)进行压缩来解决遗忘问题,确保Transformer的总上下文长度收敛到固定的上限,而不受视频持续时间的影响。这使得模型能够在不增加计算瓶颈的情况下编码更多的帧,从而促进抗遗忘。
此外,FramePack提出了抗漂移抽样方法,打破了因果预测链,并纳入了双向上下文。这些方法包括在填充中间内容之前生成两端点帧(初始帧和结尾帧),以及一种倒置的时间采样方法,其中帧以相反的顺序生成,每个帧都试图接近一个已知的高质量帧。
帧的重要性排序
典型的视频生成模型都可以表示为如下图所示,根据先前的一些帧来预测之后的一帧或者多帧:
问题是选择多少先前的帧才能有效?
作者观察到,在预测下一帧时,输入帧具有不同的重要性,并且可以根据其重要性对输入帧进行优先级排序。在不失一般性的前提下,考虑一个简单的案例:时间接近度反映了重要性,与预测目标更近的帧可能更具相关性。我们将所有帧列举出来,其中
F
0
F_0
F0是最重要的(例如最近),
F
T
−
1
F_{T-1}
FT−1是最重要的(例如最旧)。
然后作者定义了一个长度函数,该函数在VAE编码和Transformer拼接后确定每个帧的上下文长度,并对不重要的帧应用渐进压缩,这里的λ本文取的是2,λ是一个压缩参数,通过操纵输入层中的Transformer的patchify核大小来实现帧间压缩(例如,λ=2,i=5表示一个内核尺寸,其所有维度的乘积等于
2
5
=
32
2^5=32
25=32,如三维内核2×4×4或8×2×2等):
然后总上下文长度遵循几何级数,前半部分是输出帧,后半部分是输入帧:
当T趋近于无穷时,总上下文长度就是:
由于DiT中的Patch操作都是三维的,作者用(pf,ph,pw)表示帧数、高度和宽度上的步长。相同的压缩率可以通过多个可能的核尺寸来实现,例如,64的压缩率可以由(1,8,8),(4,4,4),(16,2,2),(64,1,1)等获得。这将导致不同的FramePack压缩计划。
观察到深度神经网络在不同压缩率下的特征表现出显著的差异。经验证据表明,在多个压缩率下对不同的输入预测使用独立参数有助于稳定学习。作者将最常用的输入压缩内核分配为独立的神经网络层:( 2,4,4)、( 4,8,8) 和( 8,16,16)。对于更高的压缩 (e.g. (16,32,32)) 处,我们首先进行下采样 (e.g. 用2 × 2 × 2),然后使用最大的核 (8,16,16)。这使我们能够处理所有的压缩率。在训练这些新的输入投影层时,我们通过从预训练的patchifying投影 (例如,HunyuanVideo/Wan的 (2,4,4) 投影插值来初始化它们的权重)。
尾部处理
在理论上,FramePack可以处理具有固定不变上下文长度的任意长度的视频,但当输入帧长度变得非常大时会出现实际考虑。在尾部区域,帧可能会低于最小单位大小 (例如,单个潜在像素)。作者讨论了处理尾部的3个选项 :
- 简单地删除尾部;
- 允许每个尾部框架增加一个潜在像素的上下文长度;
- 将全局平均池化应用于所有尾帧,并使用最大内核处理它们。
在测试中,这些选项之间的视觉差异相对可以忽略不计。我们注意到尾部指的是最不重要的帧,并不总是指最早的帧 (在某些情况下,我们可以为旧帧分配更高的重要性)。
RoPE对齐
当使用不同的压缩核编码输入时,不同上下文长度需要RoPE(旋转位置嵌入)对齐。RoPE为所有通道中的每个标记位置生成具有实部和虚部的复数,我们称之为“相位”。RoPE通常按通道将相位乘以神经网络特征。为了匹配压缩后的RoPE编码,我们将直接下采样(使用平均池化)RoPE相位来匹配压缩核。
# 2. RoPE
self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)
[rope定义](https://github.com/lllyasviel/FramePack/blob/b680f72df8b87340b06e701dc1a59a6687ebd962/diffusers_helper/models/hunyuan_video_packed.py#L763)
def process_input_hidden_states(
self,
latents, latent_indices=None,
clean_latents=None, clean_latent_indices=None,
clean_latents_2x=None, clean_latent_2x_indices=None,
clean_latents_4x=None, clean_latent_4x_indices=None
):
hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
B, C, T, H, W = hidden_states.shape
if latent_indices is None:
latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
rope_freqs = rope_freqs.flatten(2).transpose(1, 2)
if clean_latents is not None and clean_latent_indices is not None:
clean_latents = clean_latents.to(hidden_states)
clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
clean_latents = clean_latents.flatten(2).transpose(1, 2)
clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)
hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)
if clean_latents_2x is not None and clean_latent_2x_indices is not None:
clean_latents_2x = clean_latents_2x.to(hidden_states)
clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device)
clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)
hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)
if clean_latents_4x is not None and clean_latent_4x_indices is not None:
clean_latents_4x = clean_latents_4x.to(hidden_states)
clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device)
clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)
hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)
return hidden_states, rope_freqs
一些其他的策略
抗漂移采样
漂移是下一帧预测模型中的一个常见问题,随着视频长度的增加,视觉质量会下降。虽然根本原因仍然是一个开放的研究问题,但我们观察到漂移只发生在因果采样中(即当模型只能访问过去帧时)。我们表明提供对将来帧的访问权限(即使是一个单一的未来帧)将消除漂移。我们指出双向上下文,而不是严格的因果依赖性,可能是保持视频质量的根本因素。
图(a)中显示的vanilla采样方法,即迭代预测未来帧,可以修改为图(b),其中第一轮同时生成开始和结束部分,而后续迭代填充这些锚点之间的间隙。这种双向方法防止漂移,因为结束帧在第一轮建立,并且所有未来的帧都试图近似它们。
我们通过将图(b)中的采样顺序反转为图(c),讨论了一个重要的变体。这种方法对于图像到视频的生成是有效的,因为它可以将用户输入作为高质量的第一帧,并不断细化生成以逼近用户帧(这与图(b)不近似第一帧的情况不同),从而产生整体上高品质的视频。 这句话的意思就是图(c)这种采样方式是逐步靠近高质量的第一帧的,这样生成的新帧肯定是贴切的。
图(a,b,c)中的所有三种方法都可以生成任意长度的视频。方法(a)通过直接迭代生成来实现这一点,在方法(b)和(c)中,我们可以随着我们生成的帧接近它们而动态地将结尾部分(或生成的帧)移动到更远的距离。或者在实践中,通常在第一次迭代中设置足够大的时间范围(例如1分钟),这通常可以满足实际需求。