图片视频联合编解码器 MAGVIT-v2

资源

github(非官方实现):https://github.com/lucidrains/magvit2-pytorch
paper: https://arxiv.org/pdf/2310.05737.pdf


2024年年初因为sora的爆火,大家开始对视频生成相关的技术进行广泛的讨论和研究,sora这类视频生成的模型离不开一个重要的组件:一个能高比率压缩且高保真的视频编码器。高比率压缩可以让后续的Transformer模型学习更为容易,高保真能保证最后生成的视频质量。MAGVIT-V2是目前表现SOTA的视频编解码器,视频生成模型VideoPoet就使用它作为视频编解码器。

MAGVIT-V2的创新点有两个:

  1. 使用3D因果卷积(Casual 3D CNN)实现图片和视频的联合编解码器。(以前基于3D CNN的方法因为3D CNN的特性只能对视频建模)
    在这里插入图片描述

  2. 使用LFQ(lookup-free quantizer)将codebook增大到 2 18 2^{18} 218,也即词汇表的长度为 262144。作者进行了实验,使用VQ的方式随着codebook的增大,重建的能力虽然变强了,生成能力反而变差(想象codebook大到能容纳所有的取值,那就丢失了生成能力)。而LFQ随着codebook的增大重建和生成能力都会得到提升。同时LFQ不需要查表的过程,提升了计算效率。
    在这里插入图片描述

文中LFQ的直观理解:
MAGVIT-v2使用的是一种 LFQ 的变体,它假设码本维度独立和 latent 变量为二进制。既然LFQ无需查表,那于是可以通过二进制形式的 latent 变量转换为codebook中十进制的对应编号。
比如在VQ中过程是:
假设某一embedding为 q = [ 0.1 , 0.2 , − 0.1 , 0.8 , 0.1 ] q = [0.1,0.2,-0.1,0.8,0.1] q=[0.1,0.2,0.1,0.8,0.1] 和codebook中的第8个embedding [ 0.2 , 0.2 , − 0.1 , 0.7 , 0.1 ] [0.2,0.2,-0.1,0.7,0.1] [0.2,0.2,0.1,0.7,0.1]相似度最高,那么量化后的结果就是 q ′ = [ 0.2 , 0.2 , − 0.1 , 0.7 , 0.1 ] , t o k e n _ i d = 8 q'=[0.2,0.2,-0.1,0.7,0.1], token\_id=8 q=[0.2,0.2,0.1,0.7,0.1],token_id=8

在LFQ中过程是:
某一embedding为 q = [ 0.1 , 0.2 , − 0.1 , 0.8 , 0.1 ] q = [0.1,0.2,-0.1,0.8,0.1] q=[0.1,0.2,0.1,0.8,0.1],经过符号函数比如torch.sign()将大于0的值变为1,小于等于0的值为-1。 s i g n _ q = [ 1 , 1 , − 1 , 1 , 1 ] sign\_q = [1,1,-1,1,1] sign_q=[1,1,1,1,1],sign_q>0的位置以二进制表示,并且对应的十进制为16+8+2+1= 27。因此无需查表可以得到q的量化结果为 q ′ = [ 1 , 1 , − 1 , 1 , 1 ] , t o k e n _ i d = 27 q'=[1,1,-1,1,1], token\_id=27 q=[1,1,1,1,1],token_id=27

代码表示: 参考自https://zhuanlan.zhihu.com/p/679032979

import torch

class LookupFreeQuantizer:
    def __init__(self, vocab_size: int=None):
        """
        初始化方法
        :param vocab_size: 词汇表大小,表示要将实数值张量映射到的整数范围。如果未提供,表示不限定词汇表大小。
        """
        self.vocab_size = vocab_size

    def sign(self, z: torch.Tensor):
        """
        将张量中的每个元素转换为其符号
        :param z: 包含实数值的张量
        :return: 符号化的张量
        """
        q_z = torch.sign(z)
        q_z[q_z == 0] = -1  # 将零元素转换为-1
        return q_z

    def token_index(self, q_z: torch.Tensor):
        """
        将符号化值张量转换为整数值张量
        :param q_z: 符号化的张量
        :return: 整数值张量
        """
        indices = (torch.arange(q_z.size(-1), dtype=torch.float32)).to(q_z.device)
        tokens = torch.sum(2**indices * (q_z > 0).float(), dim=-1)
        return tokens

    def quantize(self, z: torch.Tensor):
        """
        对实数值张量进行量化
        :param z: 包含实数值的张量
        :return: 符号化的值和整数值张量
        """
        if self.vocab_size is not None:
            assert z.size(-1) == torch.log2(self.vocab_size)

        q_z = self.sign(z)
        index = self.token_index(q_z)
        return q_z, index.int()
  • 30
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值