【代码解读】阿里最新开源视频生成模型 Wan 2.1 实现解析

昨晚阿里巴巴开源了最新视频生成模型的代码和权重,从给出的 demo 效果来看还是非常惊艳的。这个模型目前也是在 VBench 榜单上排到了第一名,超越了 Sora 以及 HunyuanVideo 等一系列知名方法。
截至写文章时的 VBench 榜单
从官方给出的方法架构图来说,Wan 2.1 并没有使用 MMDiT 的架构,而是基于普通的 DiT 架构,而文本条件则是通过 Cross Attention 实现注入。在文本编码器方面,Wan 2.1 采用了支持多语言的 UMT5 作为编码器,因此 prompt 部分或许能够原生支持中文输入。图中的 Wan-Encoder 和 Wan-Decoder 实际上就是视频生成模型常用的 3D Causal VAE,根据官方的说法,其支持无损时序信息编解码任意时长1080P视频。在时间编码方面,模型的所有 block 采用了统一的时间步编码器,并采取了类似 AdaLN 的方式将时间步编码进行注入。
Wan 2.1 模型架构
Wan 2.1 公布了不同尺寸的多个变体,小型的为 1.3B,想必是为了支持消费级显卡推理推出的一款模型;大型的为 14B,超过了 HunyuanVideo 的尺度,并且支持 720P 分辨率视频的生成。从表中的信息来看,不仅支持文生视频,同时也能够支持图生视频。

模型支持 480P 分辨率支持 720P 分辨率
T2V-14B支持支持
I2V-14B-720P不支持支持
I2V-14B-480P支持不支持
T2V-1.3B支持不支持

同时官方也已经给出了一些定量指标,目前看到的生成质量指标是由人工评测得到的,所以暂时先不分析。个人感觉比较重要的信息是这张图里的推理成本。从表中可以看出,1.3B 模型的峰值显存占用仅为 8 GB,且在单张消费级显卡上推理约 4 分钟即可的得到一段视频(而且这个结果是将 T5 模型卸载到 CPU 上得到的,所以如果把文本提前做离线 embedding,这个性能应当还有进一步提升),还是很可观的。不过 14B 模型的推理成本就比较高了,在单卡上的显存占用已经接近 80 GB,推理时间也来到了几千秒的数量级。
Wan 2.1 的推理成本

代码实现分析

首先可以看到的是,和其他的方法一样,Wan 2.1 也使用了 Classifier-Free Guidance(代码链接):

noise_pred_cond = self.model(
    latent_model_input, t=timestep, **arg_c)[0]
noise_pred_uncond = self.model(
    latent_model_input, t=timestep, **arg_null)[0]

noise_pred = noise_pred_uncond + guide_scale * (
    noise_pred_cond - noise_pred_uncond)

对于图生视频任务,模型会使用 CLIP Vision Encoder 将图像进行编码作为 latents 中的第一帧,其余部分填充零,且加入一个 mask 通道(类似 inpainting 的做法,代码链接):

self.clip.model.to(self.device)
clip_context = self.clip.visual([img[:, None, :, :]])
if offload_model:
    self.clip.model.cpu()

y = self.vae.encode([
    torch.concat([
        torch.nn.functional.interpolate(
            img[None].cpu(), size=(h, w), mode='bicubic').transpose(
                0, 1),
        torch.zeros(3, 80, h, w)
    ],
                    dim=1).to(self.device)
])[0]
y = torch.concat([msk, y])

除了在 latents 上的调整,图生视频还会将图像的 CLIP 特征再次进行 embedding,并在 Cross Attention 时与文本图像拼接后共同作为条件进行生成(代码链接):

if clip_fea is not None:
    context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
    context = torch.concat([context_clip, context], dim=1)

进入模型内部,可以看到一个现象是模型的输入并不是 batched tensor,而是一个 tensor 的列表,相当于把同一个批次拆分成了多个单个视频。在推理时也是遍历整个列表,可能因为模型的推理显存比较高,通过把批次拆开来节省显存。以 Patch Embedding 为例(代码链接):

x = [self.patch_embedding(u.unsqueeze(0)) for u in x]

对于模型的每个 block,其内部由一组 self attention 与一组 cross attention 组成,并且都按照 DiT 的方式进行了 modulation 操作(代码链接):

# self-attention
y = self.self_attn(
    self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
    freqs)
with amp.autocast(dtype=torch.float32):
    x = x + y * e[2]

# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
    x = x + self.cross_attn(self.norm3(x), context, context_lens)
    y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
    with amp.autocast(dtype=torch.float32):
        x = x + y * e[5]
    return x

x = cross_attn_ffn(x, context, context_lens, e)

在图生视频的 attention 中,将文本 token 与图像 token 进行了拆分,分别与 latent 计算 cross attention,然后再将两组结果相加得到最后的交叉注意力结果(代码链接):

def forward(self, x, context, context_lens):
    r"""
    Args:
        x(Tensor): Shape [B, L1, C]
        context(Tensor): Shape [B, L2, C]
        context_lens(Tensor): Shape [B]
    """
    context_img = context[:, :257]
    context = context[:, 257:]
    b, n, d = x.size(0), self.num_heads, self.head_dim

    # compute query, key, value
    q = self.norm_q(self.q(x)).view(b, -1, n, d)
    k = self.norm_k(self.k(context)).view(b, -1, n, d)
    v = self.v(context).view(b, -1, n, d)
    k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
    v_img = self.v_img(context_img).view(b, -1, n, d)
    img_x = flash_attention(q, k_img, v_img, k_lens=None)
    # compute attention
    x = flash_attention(q, k, v, k_lens=context_lens)

    # output
    x = x.flatten(2)
    img_x = img_x.flatten(2)
    x = x + img_x
    x = self.o(x)
    return x

在计算 self-attention 时,也使用了 RoPE(代码链接):

x = flash_attention(
    q=rope_apply(q, grid_sizes, freqs),
    k=rope_apply(k, grid_sizes, freqs),
    v=v,
    k_lens=seq_lens,
    window_size=self.window_size)

目前来说从代码里能够看到的比较有用的信息就是这些,由于具体的 report 还没有放出来,所以关于数据的细节目前不太清楚(据说用了 1.5B 视频数据和 10B 图像数据),也期待一下技术报告早日公布。

### Wan2.1 本地部署教程和配置指南 #### 准备工作 为了成功部署WAN优化解决方案,确保环境满足最低硬件和软件需求。准备阶段涉及确认网络设备兼容性和操作系统版本支持。 #### 安装依赖项 安装必要的库和支持工具对于顺利设置至关重要。具体命令取决于所使用的Linux发行版: ```bash sudo apt-get update && sudo apt-get install -y build-essential libssl-dev zlib1g-dev \ libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev \ libncursesw5-dev xz-utils tk-dev libffi-dev liblzma-dev python-openssl git ``` 上述命令适用于基于Debian的系统[^4]。 #### 下载并解压Wan2.1源码包 获取官方发布的最新稳定版本压缩包,并将其放置于目标机器上适当位置后进行解压操作: ```bash wget https://example.com/path/to/wan2_1.tar.gz tar zxvf wan2_1.tar.gz cd wan2.1/ ``` 请注意替换下载链接为实际地址。 #### 编译与安装 按照项目README.md内的指示完成编译过程。一般情况下,此步骤包括但不限于执行configure脚本、make以及最终的make install指令来构建应用程序及其关联组件。 ```bash ./configure --prefix=/usr/local/wan2.1 make sudo make install ``` #### 配置服务参数 编辑wan2.1的主要配置文件`/etc/wan2.1.conf`以适应特定应用场景的需求。重点考虑以下几个方面: - **监听端口**:指定用于接收客户端请求的服务端口号。 - **最大并发连接数**:根据预期负载设定合理的上限值。 - **日志级别**:选择合适的调试信息记录等级以便后续排查问题。 #### 启动服务 利用systemd管理单元控制系统服务的状态变化。首次启动前建议先尝试手动方式验证基本功能正常运作后再加入开机自启列表中。 ```bash sudo systemctl start wan2.1.service sudo systemctl enable wan2.1.service ``` #### 测试连通性 最后一步是对新搭建好的平台进行全面的功能测试,确保各项特性均能按预期发挥作用。可以通过telnet或其他类似手段检查远程接入能力;同时借助iperf等工具评估带宽表现及延迟指标。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值