资源
github(非官方实现):https://github.com/lucidrains/magvit2-pytorch
paper: https://arxiv.org/pdf/2310.05737.pdf
2024年年初因为sora的爆火,大家开始对视频生成相关的技术进行广泛的讨论和研究,sora这类视频生成的模型离不开一个重要的组件:一个能高比率压缩且高保真的视频编码器。高比率压缩可以让后续的Transformer模型学习更为容易,高保真能保证最后生成的视频质量。MAGVIT-V2是目前表现SOTA的视频编解码器,视频生成模型VideoPoet就使用它作为视频编解码器。
MAGVIT-V2的创新点有两个:
-
使用3D因果卷积(Casual 3D CNN)实现图片和视频的联合编解码器。(以前基于3D CNN的方法因为3D CNN的特性只能对视频建模)
-
使用LFQ(lookup-free quantizer)将codebook增大到 2 18 2^{18} 218,也即词汇表的长度为 262144。作者进行了实验,使用VQ的方式随着codebook的增大,重建的能力虽然变强了,生成能力反而变差(想象codebook大到能容纳所有的取值,那就丢失了生成能力)。而LFQ随着codebook的增大重建和生成能力都会得到提升。同时LFQ不需要查表的过程,提升了计算效率。
文中LFQ的直观理解:
MAGVIT-v2使用的是一种 LFQ 的变体,它假设码本维度独立和 latent 变量为二进制。既然LFQ无需查表,那于是可以通过二进制形式的 latent 变量转换为codebook中十进制的对应编号。
比如在VQ中过程是:
假设某一embedding为
q
=
[
0.1
,
0.2
,
−
0.1
,
0.8
,
0.1
]
q = [0.1,0.2,-0.1,0.8,0.1]
q=[0.1,0.2,−0.1,0.8,0.1] 和codebook中的第8个embedding
[
0.2
,
0.2
,
−
0.1
,
0.7
,
0.1
]
[0.2,0.2,-0.1,0.7,0.1]
[0.2,0.2,−0.1,0.7,0.1]相似度最高,那么量化后的结果就是
q
′
=
[
0.2
,
0.2
,
−
0.1
,
0.7
,
0.1
]
,
t
o
k
e
n
_
i
d
=
8
q'=[0.2,0.2,-0.1,0.7,0.1], token\_id=8
q′=[0.2,0.2,−0.1,0.7,0.1],token_id=8
在LFQ中过程是:
某一embedding为
q
=
[
0.1
,
0.2
,
−
0.1
,
0.8
,
0.1
]
q = [0.1,0.2,-0.1,0.8,0.1]
q=[0.1,0.2,−0.1,0.8,0.1],经过符号函数比如torch.sign()将大于0的值变为1,小于等于0的值为-1。
s
i
g
n
_
q
=
[
1
,
1
,
−
1
,
1
,
1
]
sign\_q = [1,1,-1,1,1]
sign_q=[1,1,−1,1,1],sign_q>0的位置以二进制表示,并且对应的十进制为16+8+2+1= 27。因此无需查表可以得到q的量化结果为
q
′
=
[
1
,
1
,
−
1
,
1
,
1
]
,
t
o
k
e
n
_
i
d
=
27
q'=[1,1,-1,1,1], token\_id=27
q′=[1,1,−1,1,1],token_id=27
代码表示: 参考自https://zhuanlan.zhihu.com/p/679032979
import torch
class LookupFreeQuantizer:
def __init__(self, vocab_size: int=None):
"""
初始化方法
:param vocab_size: 词汇表大小,表示要将实数值张量映射到的整数范围。如果未提供,表示不限定词汇表大小。
"""
self.vocab_size = vocab_size
def sign(self, z: torch.Tensor):
"""
将张量中的每个元素转换为其符号
:param z: 包含实数值的张量
:return: 符号化的张量
"""
q_z = torch.sign(z)
q_z[q_z == 0] = -1 # 将零元素转换为-1
return q_z
def token_index(self, q_z: torch.Tensor):
"""
将符号化值张量转换为整数值张量
:param q_z: 符号化的张量
:return: 整数值张量
"""
indices = (torch.arange(q_z.size(-1), dtype=torch.float32)).to(q_z.device)
tokens = torch.sum(2**indices * (q_z > 0).float(), dim=-1)
return tokens
def quantize(self, z: torch.Tensor):
"""
对实数值张量进行量化
:param z: 包含实数值的张量
:return: 符号化的值和整数值张量
"""
if self.vocab_size is not None:
assert z.size(-1) == torch.log2(self.vocab_size)
q_z = self.sign(z)
index = self.token_index(q_z)
return q_z, index.int()