CLIP
CLIP(Contrastive Language-Image Pre-Training,对比语言-图像预训练)是一种在各种图像文本对上训练的神经网络。它可以用自然语言进行指示,以预测最相关的文本片段,给定图像,而无需直接针对任务进行优化,类似于 GPT-2 和 GPT-3 的zero-shot功能。
CLIP 在 ImageNet 上zero-shot(无需使用任何原始的 1.28M 标记样本)与原始 ResNet50 的性能相当。
摘要
SOTA的CV系统,固定一组标签和分类头,需要新增标记数据来指定新类,这限制了拓展
文本相对固定的标签来说,可作为更广泛的监督信息
对于图像,预测相匹配的caption可作为简单有效且可拓展的方法
- 训练:400 million 互联网上的图像文本对 training from scratch
- 推理:用语言描述概念(类),实现下游任务的 zero-shot
模型迁移能力优秀,取得了与完全监督的Baseline可比的结果
引言
近几年,文本预训练(自回归、掩码)彻底改变了NLP,其在计算、模型容量、数据方面扩展了许多数量级。使得NLP上预训练模型能zero-shot迁移到下游数据集。但目前计算机视觉还为探索
已有大量工作探索图像和文本之间的联系,但使用自然语言进行图像表征学习的工作仍然很少,其原因可能是测试基准问题。
以前的工作,没有大量的原始文本,因此倾向于在有限的监督与几乎无限的原始文本之间妥协。这与NLP中的预训练关键区别就是规模。
提出 Contrastive Language -Image Pre-training (CLIP),CLIP,类似于GPT家族,在预训练期间学习执行广泛的任务,包括OCR,地理定位,动作识别等。
CLIP取得了与特定于任务(task-specific)的监督模型可比的结果。
CLIP优于SOTA的ImageNet模型,且具有更高的计算效率和鲁棒性。
方法
本部分来源于CLIP论文,主要讲述了CLIP的动机和效果优秀的解释,详细的方法过程主要集中在后面的代码部分。
自然语言监督
相比于标签式的监督,用自然语言进行监督更加灵活,且将能学习与语言相关联的表征,从而实现灵活的零样本迁移。
创建足够大的数据集
数据集,其中包含4亿对(图像,文本)对,这些数据来自互联网上各种公开可用的资源。
我们搜索(图像,文本)对作为构建过程的一部分,其文本包含500,000个查询的集合(基本查询列表是所有在英文版维基百科中出现至少100次的单词。这是用bi-gram增广的)。通过每个查询包含多达20,000对(图像、文本)来平衡结果。
选择有效的预训练方法
训练效率是成功扩展自然语言监督的关键,因此基于该指标选择最终的预训练方法。
图像到文本的预测任务艰巨,学习效率低下,而对比可以更好的表征。图像的生成模型相比同性能的对比模型,需要超过一个数量级的计算量。
因此选择的预训练任务为:预测哪个文本与哪张图像配对,而不是该文本的确切单词。
CLIP通过联合训练图像编码器和文本编码器来学习多模态嵌入空间,以最大化批处理中 N N N对真实对的图像和文本嵌入的余弦相似度,同时最小化 N 2 − N N^2−N N2−N对错误对的嵌入的余弦相似度。在这些相似性得分上优化对称交叉熵损失。
相较于对比学习,CLIP从头训练而非利用ImageNet权重初始化的图像编码器或者文本编码器。只利用线性投影将每个编码器的表征映射到多模态嵌入空间。
对于文本,未采用均匀单句采样,因为描述多为单句
对于图像,训练时只使用random square crop 和 resize两种增强
模型结构
图像编码器:
- 架构1
- ResNet作为基础架构
- 采用ResNet-D改进
- 使用rect-2 blur pooling
- 将global average pooling 替换为 attention pooling 机制:单层tranormer式的多头注意力,其中Q为图像的全局
平均池化 - 模型缩放时,同时增加深度、宽度和分辨率
- 架构2
- ViT,仅对transformer 前面的 combined patch and position embeddings 添加了额外的layer normalization,并使用略有不同的初始化方案
文本编码器:transformer架构,参照修改的架构Language models are unsupervised multitask learners
- base 64M参数 12层 512-wide 8 heads
- vocab size:49,152
- max sequence length:76
- start and end tokens:[SOS] [EOS],transformer最后一层[EOS]处的激活作为文本的表征,经过层归一化后投影到多模态嵌入空间中。
- masked self-attention
补充内容:
(1)Transformer encoder layer 接受一个序列,输出一个相同尺寸的序列
(2)Text encoder
基于Transformer架构,在最后一个token处使用线性投影得到特征
(2)Image encoder
基于ViT架构,将图像分块嵌入,并在最前面加入[CLS] token,在该token对应的最后一层使用线性投影得到特征
代码阅读
CLIP 分为训练和zero-shot两个阶段,为了利用自然语言来进行监督,CLIP在训练阶段进行:
- 利用 text encoder 和 image encoder 得到视觉和文本表征 I I I, T T T
- 计算 I @ T I @ T I@T (矩阵乘法),得到相似度矩阵
- 要求对角线上(正确匹配)的对具有较高相似度,而非对角线上具有低相似度。
针对CLIP,本文只介绍一些值得注意的部分代码。
掩码注意力
首先,简要说明论文中描述的掩码注意力(因果注意力)
def build_causal_mask(self):
'''因果注意力掩码。'''
# lazily create causal attention mask, with full attention between the tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.num_pos, self.num_pos)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
- 首先,代码创建了一个大小为self.num_pos × \times ×self.num_pos 的空白矩阵mask,并用float(“-inf”)填充。
- 然后,代码将mask矩阵的上三角部分设置为负无穷大,以表示在注意力计算过程中,模型不应关注当前位置之后的 token。
- 最后,代码将mask矩阵的上三角部分向下移动一行,以使其与序列中的 token 对齐。这样,模型在处理序列时将只关注当前位置之前的 token。
标签生成
CLIP对一个batch的图像文本对进行编码,然后利用矩阵乘法得到图像与文本的相似度矩阵。为了让网络能预测图像与文本的匹配,CLIP要求对角线上(正确匹配)的相似度尽可能高,非对角线上(错误匹配)的相似度低。具体来说,CLIP为每一行(或每一列)生成一个标签,对角线的索引就行标签的值。
即一个batch的标签为:
[
1
,
2
,
3
,
.
.
.
,
n
]
[1, 2, 3, ..., n]
[1,2,3,...,n](行列相同)
为了利用交叉熵来进行对比学习,先生成标签
def get_ground_truth(self, device, num_logits) -> torch.Tensor:
# calculated ground-truth and cache if enabled
if self.prev_num_logits != num_logits or device not in self.labels:
'''核心代码:为每个图像文本对生成序号作为标签'''
labels = torch.arange(num_logits, device=device, dtype=torch.long)
if self.world_size > 1 and self.local_loss:
labels = labels + num_logits * self.rank
if self.cache_labels:
self.labels[device] = labels
self.prev_num_logits = num_logits
else:
labels = self.labels[device]
return labels
结果展示:
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
logit_scale:可学习参数
self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
'''利用矩阵乘法计算图像特征和文本特征之间的相似度'''
def get_logits(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
image_logits = self.logit_scale.exp() * image_features @ text_features.T
if self.logit_bias is not None:
image_logits += self.logit_bias
text_logits = image_logits.T
return image_logits, text_logits
def forward(self, image_features, text_features, logit_scale, output_dict=False):
device = image_features.device
'''计算图像特征和文本特征之间的相似度。'''
logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
'''序号标签'''
labels = self.get_ground_truth(device, logits_per_image.shape[0])
'''损失'''
total_loss = (
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2
return {"contrastive_loss": total_loss} if output_dict else total_loss
为了在已确定的计算资源上增大batch,从而或得更好的对比性能,CLIP做了适当的修改来进行梯度累计。
梯度累积(Gradient Accumulation):将多个mini-batch的样本前向加后向得到梯度后,累积起来,计算平均梯度再更新参数
在CLIP中,对于一个完整的batch,拆分成多个mini-batch后,各个mini-batch内部天然能够计算相似度,但mini-batch之间不能,因此,直接进行梯度累积将损失这部分的对比。
为了增加batch,获得更多的对比,CLIP先进行accum_freq个mini-batch的推理(不记录梯度),获得一个标量的历史值,然后重新对每一个mini-batch做记录梯度的前向传播,在拼接原来的accum_freq-1个历史值,从而达到更大的对比量。如下图所示。
for i, batch in enumerate(dataloader):
images, texts = batch
images = images.to(device=device, dtype=input_dtype, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)
''' 缓存历史特征,不进行 gradient tracking '''
with torch.no_grad():
with autocast():
model_out = model(images, texts)
for f in ("logit_scale", "logit_bias"):
model_out.pop(f, None)
for key, val in model_out.items():
if key in accum_features:
accum_features[key].append(val)
else:
accum_features[key] = [val]
accum_images.append(images)
accum_texts.append(texts)
''' 未达到累积的batch时,跳过下面的参数更新 '''
if ((i + 1) % args.accum_freq) > 0:
continue
optimizer.zero_grad()
for j in range(args.accum_freq):
images = accum_images[j]
texts = accum_texts[j]
with autocast():
model_out = model(images, texts)
inputs_no_accum = {}
inputs_no_accum["logit_scale"] = logit_scale = model_out.pop("logit_scale")
if "logit_bias" in model_out:
inputs_no_accum["logit_bias"] = model_out.pop("logit_bias")
inputs = {}
for key, val in accum_features.items():
accumulated = accum_features[key]
''' 关键步骤:拼接有梯度的记录的当前batch嵌入和标量嵌入,同时保持位置关系'''
inputs[key] = torch.cat(accumulated[:j] + [model_out[key]] + accumulated[j + 1:])
losses = loss(**inputs, **inputs_no_accum, output_dict=True)
del inputs
del inputs_no_accum
total_loss = sum(losses.values())
losses["loss"] = total_loss
backward(total_loss, scaler)
if args.accum_freq > 1:
accum_images, accum_texts, accum_features = [], [], {}
推理阶段
针对下游任务,CLIP无需进行分类头的训练或微调,直接进行Zero-Shot推理:
- 设计第 i i i 类对应的提示词:a photo of a [CLS]
- 计算图像与提示词的匹配程度:
p ( y = i ∣ x ) = exp ( ⟨ w i , z ⟩ / τ ) ∑ j = 1 K exp ( ⟨ w i , z ⟩ / τ ) p(y=i|x)=\frac{\exp({\langle w_i, z\rangle}/ \tau)}{\sum_{j=1}^{K} \exp(\langle w_i, z\rangle/ \tau)} p(y=i∣x)=∑j=1Kexp(⟨wi,z⟩/τ)exp(⟨wi,z⟩/τ)
只需要在标签空间中寻找与图像匹配程度最高的文本,就能确定标签。