CLIP模型技术学习
前言
最近多模态太火了,学了MiniCPM-V、ViT和transformer,现在开始学CLIP,笔记记录一下,如果有理解不到位的欢迎批评指正。
不需要下游任务微调,图-文对比学习训练的模型就能胜任下游任务?!
- CLIP出自OpenAI发表在ICML 2021的论文Learning Transferable Visual Models From Natural Language Supervision
- 文中提出了一种图文跨模态对齐的方法(基于对比学习),并且发现训练的模型可以很好地泛化到新的任务
CLIP模型架构
CLIP包含一个文本编码器和一个图像编码器
- 文本编码器:Transformer
- 图像编码器:ViT,或者ResNet
CLIP模型代码结构
- CLIP类的代码,forward()里获取图片和文本特征
- 文本特征提取与bert类似,取了[EOS]对应的向量,self.text_projection是把文本投影到图-文隐空间维度,这里没有加激活函数使用非线性变换
- 提取图片特征后也是会乘一个proj矩阵,让文本和图片特征统一到同一个维度
def encode_image(self, image):
return self.visual(image.type(self.dtype))
def encode_text(self, text):
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND,batch_first的改一下维度,在之前ViT的代码中可以看到
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
return image_features, text_features
CLIP模型代码结构——ResNet图片特征
在CLIP中如果使用Resnet作为图片编码器,相比于原始的ResNet文中做了一点修改,在layer4中后面没有使用avg_pool计算均值,而是使用一个AttentionPool2d层计算加权均值
AttentionPool2d
- CLIP类的代码中,AttentionPool2d里面的query向量是输入特征的均值
- 在多头注意力模块MHA中,输入query大小为[1,N,C],key=value大小为[(HW+1),N,C]
- MHA模块的输出为和输入query是一样的大小为[1,N,C],最终返回的为[N,C]
# AttentionPool2d
def forward(self, x):
x = x.flatten(start_dim=2).permute(2, 0, 1) # [N,C,H,W] -> [(HW),N,C]
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # [(HW+1),N,C]
x = x + self.positional_embedding[:, None, :].to(x.dtype) # [(HW+1),N,C]
x, _ = F.multi_head_attention_forward(query=x[:1], key=x, value=x,out_proj_weight=self.c_proj.weight)
return x.squeeze(0)
CLIP 模型代码结构——inference
推理过程
- 如果要做分类任务,先需要写一个prompt,label是分类任务的标签,例如cifar-100分类,需要写This is a photo of a cat / This is a photo of a dog …一百个prompt作为文本描述
- 然后@运算符进行矩阵乘法,计算要推理的图和文本的相似度
- Label可以是训练集中没有的,实现zero-shot
text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()
with torch.no_grad():
image_features = model.encode_image(image_input).float()
text_features = model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)
CLIP 模型代码结构——train
- 与推理过程类似,计算图文向量的相似度
- 训练过程中恰好相似度矩阵的对角线就是图文匹配的label
image_features = model.encode_image(image_input)
text_features = model.encode_text(text_tokens)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logit_scale * text_features @ image_features.t()
labels = torch.arange(len(logits_per_image)).to(logits_per_image.device)
image_loss = F.cross_entropy(logits_per_image, labels)
text_loss = F.cross_entropy(logits_per_text, labels)
loss = (image_loss + text_loss) / 2
下一篇
下一篇应该可以看看SAM、MiniCPM-V里面压缩图片编码时采用的类似Q-former结构、SigLip或者MiniCPM-V的文本LLM了