BEIT V2: Masked Image Modeling with Vector-Quantized Visual Tokenizers
使用 Vector-Quantized 视觉标记的 Masked Image Modeling
code: https://aka.ms/beitv2
摘要
观点:大多数现有的研究都是针对低级图像像素的,这阻碍了对表示模型的高级语义的开发。
改进:使用语义丰富的 visual tokenizer 作为 masked 预测的重建目标,为将MIM从像素级提升到语义级提供了一种系统的方法。
具体:vector-quantized 的知识提取来训练 tokenizer ,该 tokenizer 将连续的语义空间离散化为紧凑的代码。然后,我们通过预测 masked image patches 的原始 visual tokens 来预训练视觉 Transformers 。
效果:在图像分类和语义分割方面的实验表明,BEITV2优于所有比较的MIM方法。在ImageNet-1K(224尺寸)上,基本尺寸BEIT V2实现了85.5%的微调精度和80.1%的线性预测精度。大尺寸BEIT V2 微调中top-1准确率 87.3%,在ADE20K上mIoU 56.7%
1. Introduction
现有的MIM方法可以根据重建目标大致分为三类:
- 低级图像像素(例如,MAE;CIM)
- 手工特征(例如,HOG特征;MaskFeat)
- visual tokens : BEiT; PECO;
我们在ImageNet-1k上对 base- and large-size vision Transformers 进行了自监督学习,并对下游任务进行了评估,例如图像分类、线性预测和语义分割。如图1所示
贡献总结如下:
我们提出了 Vector-Quantized 知识提取,将MIM从像素级提升到语义级,用于自监督表示学习。
我们引入了一种 patch 聚合策略,该策略在给定离散语义 token 的情况下强制执行全局结构,并提高了学习表示的性能。
我们对下游任务进行了广泛的实验,包括ImageNet微调、线性预测和语义分割。实验结果表明,所提出的方法显著提高了模型大小、训练步骤和下游任务的性能。
2. Methodology
图像表示
The vision Transformers 被用作 backbone ,以获得图像表示。
The input image (H×W ×C)is reshaped to N = H*W /P^2 patches
其中:(P, P ) is the patch size
在实验中,每个224×224的图像被分割成14×14的图像 patches 网格,其中每个patch是16×16。图像patches被 flattened 并线性投影成embeddings,输入到 Transformers
Vector Quantize & CodeBook [非论文部分]
简介
- 本部分参考:vector_quantize_pytorch源码解析——VectorQuantize - 知乎
- 代码地址:GitHub - lucidrains/vector-quantize-pytorch: Vector Quantization, in Pytorch
Vector Quantize,中文翻译为 向量量化,它其实跟embedding很类似。embedding是根据索引来找内部的code table。
在本文中则是 Vector-Quantized 在codebook 中查找每个 patch 的最近邻居 vj,设 codebook embeddings :{ v1,v2,··,vK }
VectorQuantize的例子
VectorQuantize是将每一个vector与内部的codebook做计算,然后得出一个quantized vector
下面是一个例子:
import vector_quantize_pytorch as vq
import torch
# 输入为一个三维张量,(1,3,2),它的最后一维必须和 VectorQuantize 初始化中的 dim 一致
a = torch.FloatTensor([-0.1, 0.5, 0.2, 0.33, -0.6, 0.2]).view(1, 3, 2)
print('a=', a)
# codebook_size指定了codebook的大小,所有的vector最终会quantized入这个大小之内。
# 比如这里是6,那么最后每一个vector就会被映射到0至5,共6个数之内,也就是indices输出
quantizer = vq.VectorQuantize(dim=2, codebook_size=6)
# quantized是根据indices找出codebook中的值。所以它跟embedding很大不同是,最后输出的维度(最后一维)是不会变的
quantized, indices, loss = quantizer(a)
print('quantized', quantized)
print('indices', indices)
print('loss', loss)
输出:
简析Vector Quantize源码
内部的codebook
Vector Quantize内部一共有两个codebook的实现,EuclideanCodebook和CosineSimCodebook。
VectorQuantize有一个入参,use_cosine_sim = False,
指定使用哪一个实现。默认的False代表使用 EuclideanCodebook。
我们就来看一下EuclideanCodebook
def kmeans(
samples,
num_clusters,
num_iters = 10,
use_cosine_sim = False,
sample_fn = batched_sample_vectors,
all_reduce_fn = noop
):
num_codebooks, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device
means = sample_fn(samples, num_clusters)
for _ in range(num_iters):
if use_cosine_sim:
dists = samples @ rearrange(means, 'h n d -> h d n')
else:
dists = -torch.cdist(samples, means, p = 2)
buckets = torch.argmax(dists, dim = -1)
bins = batched_bincount(buckets, minlength = num_clusters)
all_reduce_fn(bins)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype)
new_means.scatter_add_(1, repeat(buckets, 'h n -> h n d', d = dim), samples)
new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1')
all_reduce_fn(new_means)
if use_cosine_sim:
new_means = l2norm(new_means)
means = torch.where(
rearrange(zero_mask, '... -> ... 1'),
means,
new_means
)
return means, bins
class EuclideanCodebook(nn.Module):
def __init__(...):
# 初始化函数,根据入参kmeans_init来决定
init_fn = uniform_init if not kmeans_init else torch.zeros
# 内部的codebook实际上就是这里embed的张量,shape为(num_codebooks, codebook_size, dim)
# Beit v2 : dim = 32
embed = init_fn(num_codebooks, codebook_size, dim)
# 必填入参,表示codebook的大小
self.codebook_size = codebook_size # Beit v2 : 8192
self.num_codebooks = num_codebooks # 1
... ...
@torch.jit.ignore
def init_embed_(self, data):
if self.initted:
return
embed, cluster_size = kmeans(
data,
self.codebook_size,
self.kmeans_iters,
sample_fn = self.sample_fn,
all_reduce_fn = self.kmeans_all_reduce_fn
)
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed.clone())
self.cluster_size.data.copy_(cluster_size)
self.initted.data.copy_(torch.Tensor([True]))
@autocast(enabled = False)
def forward(self, x):
... ...
# x 改变维度
flatten = rearrange(x, 'h ... d -> h (...) d')
self.init_embed_(flatten)
... ...
# 计算 embed 与 flatten 两个矩阵之间的距离,因为要求 argmax,所以取了负值
# Beit v2 : flatten的shape是(1, N, 32), embed的shape是(1, 8192 , 32)
# dist的shape是(1,N, 8192), 它表示的是flatten里每个vector到embed每个vector的距离
dist = -torch.cdist(flatten, self.embed, p = 2)
# argmax 分别找到距离最小的, 有N个向量,所以emd_ind的shape为(1,N)
embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp)
# 转成one_hot,shape为(1,N,8192)
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
# 从(1,N)转成(1,1,N)
embed_ind = embed_ind.view(*shape[:-1])
# 根据embed_ind(indices)来在self.embed(codebook)找对应的映射值,其实这一步就跟nn.embedding很像了
quantize = batched_embedding(embed_ind, self.embed)
# training=True是默认值,所以一般会走进来
if self.training:
# cluster_size的shape为(1,8192) 表示每个code在本次出现的个数
cluster_size = embed_onehot.sum(dim = 1)
self.all_reduce_fn(cluster_size)
# self.cluster_size默认为zero张量,在这里会逐渐跟新
ema_inplace(self.cluster_size, cluster_size, self.decay)
# 这个运算就是 embed_onehot 乘 flatten,矩阵乘法
# flatten的shape为(1,N,32), embed_onehot的shape为(1,N,8192)
# embed_sum的shape为(1,8192,32)
# 这里乘法的意义应该是找出最小indecis数组中两个维度的weights,
# 然后紧接着用于更新到embed_avg中
embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
# update
ema_inplace(self.embed_avg, embed_sum, self.decay)
# 拉普拉斯平滑
cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()
embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
self.embed.data.copy_(embed_normalized)
self.expire_codes_(x)
if needs_codebook_dim:
quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
return quantize, embed_ind
VectorQuantize forward
class VectorQuantize(nn.Module):
def __init__(...):
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
... ...
def forward(self,x,mask = None):
# 按 Beit v2 的例子,shape=(1,N,32)
if self.accept_image_fmap: # 二维图片作为输入,要改维度
height, width = x.shape[-2:]
x = rearrange(x, 'b c h w -> b (h w) c')
x = self.project_in(x)
# codebook即为EuclidianCodebook或者CosineSimCodebook
# 计算得到codebook中的映射,以及indices
quantize, embed_ind = self._codebook(x)
loss = torch.tensor([0.], device = device, requires_grad = self.training)
if self.training:
# 默认为True
# 加上残值
quantize = x + (quantize - x).detach()
if self.commitment_weight > 0: # 训练过程中loss传递的比例
if exists(mask):
# with variable lengthed sequences
commit_loss = F.mse_loss(quantize, x, reduction = 'none')
commit_loss = commit_loss[mask].mean()
else:
commit_loss = F.mse_loss(quantize, x)
loss = loss + commit_loss * self.commitment_weight
if self.orthogonal_reg_weight > 0:
codebook = self._codebook.embed
if 仅计算正交损失:
unique_code_ids = torch.unique(embed_ind)
codebook = codebook[unique_code_ids]
num_codes = codebook.shape[0]
if 存在正交code数量限制 and num_codes > 正交code数量限制:
rand_ids = torch.randperm(num_codes, device = device)[:正交code数量限制]
codebook = codebook[rand_ids]
orthogonal_reg_loss = orthogonal_loss_fn(codebook)
loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
quantize = self.project_out(quantize)
if self.accept_image_fmap: # 图像还原
quantize = rearrange(quantize, 'b (h w) c -> b c h w', h = height, w = width)
embed_ind = rearrange(embed_ind, 'b (h w) ... -> b h w ...', h = height, w = width)
return quantize, embed_ind, loss
Training Visual Tokenizer
如图:visual tokenizer 的训练Pipeline ,训练后,每个图像被转换为离散的 visual token,Vector-Quantized 知识蒸馏(VQ-KD)来训练 visual tokenizer
- visual tokenizer 由 视觉 Transformers encoder 和量化器组成。
- Transformers encoder 首先将输入图像编码为向量。灰色条: z = [z1,z2,··,zN]
- V 视觉词汇表(又称 codebook )包含K个离散 codebook embeddings
- Vector-Quantized 在codebook 中查找每个 patch 表示hi 的最近邻居vj。设{v1,v2,··,vK} 表示 codebook embeddings,对于第i个图像patch,其量化代码计算为。
其中:
1)j∈{1,2,··,K}
2)L2一化用于 codebook 查找
3)上述距离相当于根据余弦相似度找到code。
- 在将图像量化为 visual token 之后,我们将L2-normalized codebook embeddings 馈送到解码器。
-
解码器也是一个多层Transformer。
-
输出向量oi 旨在重建 teacher 模型的语义特征,例如DINO;CLIP。设ti表示第i个图像patch 的 teacher 模型的特征向量。
-
在训练过程中,我们最大化解码器输出oi和 teacher 指导 ti之间的余弦相似性。
- 因为量化过程(方程1)是不可微分的,所以梯度直接从解码器输入复制到编码器输出,图2,以将梯度反向传播到编码器。
- 量化器为每个编码器输出查找最近的code,而 codebook embeddings 的梯度指示编码器的有用优化方向。
VQ-KD的训练目标定义为:公式(2)
- 其中sg[·]代表停止梯度算子,该算子前向传时是恒等式,而在后向传时具有零梯度。
- D代表用于 tokenizer 训练的图像数据。
# 构建模型 注册到timm -- modeling_vqkd.py
from timm.models.registry import register_model
# 按配置构建 vqkd
@register_model
def vqkd_encoder_base_decoder_3x768x12_clip(...):
encoder_config, decoder_config = get_model_default_params(), get_model_default_params()
# encoder settings (encoder_config 字段填充)
...
# decoder settings (decoder_config 字段填充)
...
# teacher settings
teacher_model_type = 'clip'
decoder_out_dim = 512
model = VQKD(encoder_config, decoder_config, n_code, code_dim, teacher_model_type=teacher_model_type,
decoder_out_dim=decoder_out_dim, **kwargs)
return model
# vqkd 模型
class VQKD(nn.Module):
def __init__(self,
encoder_config,
decoder_config,
n_embed=8192,
embed_dim=32,
decay=0.99,
process_type='default',
quantize_kmeans_init=True,
teacher_model_type='clip',
decoder_out_dim=512,
rec_loss_type='cosine',
**kwargs
):
# 编码器
self.encoder = VisionTransformer(**encoder_config)
# task layer
self.encode_task_layer = nn.Sequential(
nn.Linear(encoder_config['embed_dim'], encoder_config['embed_dim']),
nn.Tanh(),
nn.Linear(encoder_config['embed_dim'], embed_dim) # for quantize)
# 见 NormEMAVectorQuantizer 源码解析
self.quantize = NormEMAVectorQuantizer(n_embed=n_embed, embedding_dim=embed_dim, beta=1.0, kmeans_init=quantize_kmeans_init, decay=decay,)
# 解码器
self.decoder = VisionTransformer(**decoder_config)
self.decode_task_layer = nn.Sequential(
nn.Linear(decoder_config['embed_dim'], decoder_config['embed_dim']),
nn.Tanh(),
nn.Linear(decoder_config['embed_dim'], self.decoder_out_dim),)
def encode(self, x):
encoder_features = self.encoder(x, return_patch_tokens=True) # ViT
to_quantizer_features = self.encode_task_layer(encoder_features.type_as(self.encode_task_layer[-1].weight))
N = to_quantizer_features.shape[1]
h, w = int(math.sqrt(N)), int(math.sqrt(N))
to_quantizer_features = rearrange(to_quantizer_features, 'b (h w) c -> b c h w', h=h, w=w) # reshape for quantizer
quantize, loss, embed_ind = self.quantize(to_quantizer_features)
return quantize, embed_ind, loss
def forward(self, x, **kwargs):
"""
x: shape [B, 3, H, W] in [0, 1]
"""
x = self.pre_process(x) # 图像归一化预处理 rescale to [-1, 1]
target = self.get_regress_target(x, **kwargs) # target = teacher_clip_model.encode(x)
quantize, embed_ind, emb_loss = self.encode(x) # students 编码器
xrec = self.decode(quantize) # students 解码器
# loss
rec_loss = self.calculate_rec_loss(xrec, target) # 计算 cosine 损失
loss = emb_loss + rec_loss
return loss, log
训练
# 来自训练代码 run_vqkd_training.py 通过 timm 加载模型
# parser.add_argument('--model', default='vqkd_encoder_base_decoder_3x768x12_clip'
# parser.add_argument('--codebook_n_emd', default=8192
# parser.add_argument('--codebook_emd_dim', default=32
# parser.add_argument('--quantize_kmeans_init', action='store_true'
# parser.add_argument('--input_size', default=224
# -- regress feature
# parser.add_argument('--teacher_model_type', default='clip'
# parser.add_argument('--teacher_input_size', default=224
model = timm.models.create_model('vqkd_encoder_base_decoder_3x768x12_clip')
提高 codebook 利用率
- vector quantization 训练的一个常见问题是 codebook 崩溃。
- 缓解这个问题:指数移动平均用于更新 codebook embedding 。
NormEMAVectorQuantizer 源码解析
注释:
torch.randperm(n):将0~n-1(包括0和n-1)随机打乱后获得的数字序列,函数名是random permutation缩写
【sample】
>>> torch.randperm(10)
tensor([2, 3, 6, 7, 8, 9, 1, 5, 0, 4])
文件:norm_ema_quantizer.py
def sample_vectors(samples, num):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num: # num = 8192
indices = torch.randperm(num_samples, device = device)[:num] # 函数意思详见注释
else:
indices = torch.randint(0, num_samples, (num,), device = device) # batchsize = 64 所以 是长度为 8192 的向量,数值范围是 0~63
return samples[indices] # 将样本随机扩展到 8192
def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
# 只是一个 batch 的 samples
# num_clusters = 8192
dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
means = sample_vectors(samples, num_clusters) # 将样本随机扩展到 8192
for _ in range(num_iters): # 迭代10次 kmeans
if use_cosine_sim:
dists = samples @ means.t() # -------------------------------------20230418 xws torch.matmul
else:
... ...
buckets = dists.max(dim = -1).indices
bins = torch.bincount(buckets, minlength = num_clusters)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype)
new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples)
new_means = new_means / bins_min_clamped[..., None]
if use_cosine_sim:
new_means = l2norm(new_means)
means = torch.where(zero_mask[..., None], means, new_means)
return means, bins
class EmbeddingEMA(nn.Module):
def __init__(...):
...
if codebook_init_path == '':
if not kmeans_init: # 随机初始化
weight = torch.randn(num_tokens, codebook_dim)
weight = l2norm(weight)
else: # 因为要用 kmeans init 所以初始化为 0
weight = torch.zeros(num_tokens, codebook_dim)
self.register_buffer('initted', torch.Tensor([not kmeans_init]))
else:
codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
weight = codebook_ckpt_weight.clone()
self.register_buffer('initted', torch.Tensor([True]))
self.weight = nn.Parameter(weight, requires_grad = False)
self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
self.update = True
@torch.jit.ignore
def init_embed_(self, data):
if self.initted:
return
embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim = True)
self.weight.data.copy_(embed)
self.cluster_size.data.copy_(cluster_size)
self.initted.data.copy_(torch.Tensor([True]))
def forward(self, embed_id):
return F.embedding(embed_id, self.weight)
class NormEMAVectorQuantizer(nn.Module):
def __init__(...):
# setting
self.codebook_dim = 32
self.num_tokens = 8192
self.beta = 1.0
self.decay = 0.99
kmeans_init = True
self.statistic_code_usage=True
codebook_init_path = ''
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, 1e-5, kmeans_init, codebook_init_path)
self.register_buffer('cluster_size', torch.zeros(n_embed)) # [],shape 8192
self.all_reduce_fn = nn.Identity()
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
# z, 'b c h w -> b h w c'
z = rearrange(z, 'b c h w -> b h w c')
z = l2norm(z)
z_flattened = z.reshape(-1, self.codebook_dim) # (b, 32)
self.embedding.init_embed_(z_flattened)
d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
self.embedding.weight.pow(2).sum(dim=1) - 2 * \
torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(encoding_indices).view(z.shape)
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
if not self.training:
with torch.no_grad():
cluster_size = encodings.sum(0)
self.all_reduce_fn(cluster_size)
ema_inplace(self.cluster_size, cluster_size, self.decay)
if self.training and self.embedding.update:
#EMA cluster size
bins = encodings.sum(0)
self.all_reduce_fn(bins)
# self.embedding.cluster_size_ema_update(bins)
ema_inplace(self.cluster_size, bins, self.decay)
zero_mask = (bins == 0)
bins = bins.masked_fill(zero_mask, 1.)
embed_sum = z_flattened.t() @ encodings
self.all_reduce_fn(embed_sum)
embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
embed_normalized = l2norm(embed_normalized)
embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
embed_normalized)
norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)
# compute loss for embedding
loss = self.beta * F.mse_loss(z_q.detach(), z)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
#z_q, 'b h w c -> b c h w'
z_q = rearrange(z_q, 'b h w c -> b c h w')
return z_q, loss, encoding_indices
预训练BEIT V2
我们遵循BEIT中的MIM设置来预训练视觉Transformer:
- 给定输入图像x,大约40%的图像 patch 被masked。masked 位置被称为M。
- 然后, Mask 部分用 表示的特征代替原图
- 随后,我们为输入准备了一个可学习的[CLS] token 送到视觉 Transformer。
- 最后的编码向量表示为 h 长度为N,其中h0表示[CLS] token 。
- 接下来,我们将MIM head 实例化为一个简单的全连接层。对于每个掩蔽位置,softmax分类器预测视觉标记 p=softmax(Wh+b)
- MIM的训练损失定义为:其中zi表示原始图像的视觉 token ,D表示预训练图像
Pretraining global representation.
我们为 global 图像表示 预处理 [CLS] token。目标是减轻 patch 级预训练和图像级表示聚合之间的差异。
如图3所示,配备了patch 聚合的MIM框架, 构建了一个表示 bottleneck ,以鼓励[CLS] token 尽可能多地收集信息。
对于L层Transformer,第L层的输出向量,其中L∈{1,2,··,L}。
为了预训练 最后一层的[CLS]令牌 ,我们将其与中间第l层的patch向量 连接 形成 S 。
然后,我们将 S 馈送到浅(例如,两层)Transformer解码器,并再次进行masked 预测
注意,这两个MIM头的参数是共享的,并且在 masked 位置也计算MIM损耗,如方程3所示。
预训练损失是LMIM和LcMIM的总和。损失项LcMIM明确鼓励[CLS] token 将 patch 信息聚合为全局表示。
该模型倾向于将全局信息推送到h_CLS。
信息流 bottleneck 促使[CLS] token 比未经训练的 token 更可靠地使用全局表示。
请注意,新添加的浅层解码器仅用于预训练[CLS] token
Experiments
BEIT V2 > CAE