【多模态】CLIP模型技术学习

前言

最近多模态太火了,学了MiniCPM-V、ViT和transformer,现在开始学CLIP,笔记记录一下,如果有理解不到位的欢迎批评指正。

不需要下游任务微调,图-文对比学习训练的模型就能胜任下游任务?!

  • CLIP出自OpenAI发表在ICML 2021的论文Learning Transferable Visual Models From Natural Language Supervision
  • 文中提出了一种图文跨模态对齐的方法(基于对比学习),并且发现训练的模型可以很好地泛化到新的任务
    在这里插入图片描述

CLIP模型架构

CLIP包含一个文本编码器和一个图像编码器

  • 文本编码器:Transformer
  • 图像编码器:ViT,或者ResNet
    在这里插入图片描述

CLIP模型代码结构

  • CLIP类的代码,forward()里获取图片和文本特征
  • 文本特征提取与bert类似,取了[EOS]对应的向量,self.text_projection是把文本投影到图-文隐空间维度,这里没有加激活函数使用非线性变换
  • 提取图片特征后也是会乘一个proj矩阵,让文本和图片特征统一到同一个维度
def encode_image(self, image):
    return self.visual(image.type(self.dtype))
def encode_text(self, text):
    x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
    x = x + self.positional_embedding.type(self.dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND,batch_first的改一下维度,在之前ViT的代码中可以看到
    x = self.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = self.ln_final(x).type(self.dtype)
    # x.shape = [batch_size, n_ctx, transformer.width]
    # take features from the eot embedding (eot_token is the highest number in each sequence)
    x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
    return x
def forward(self, image, text):
    image_features = self.encode_image(image)
    text_features = self.encode_text(text)
    return image_features, text_features

CLIP模型代码结构——ResNet图片特征

在CLIP中如果使用Resnet作为图片编码器,相比于原始的ResNet文中做了一点修改,在layer4中后面没有使用avg_pool计算均值,而是使用一个AttentionPool2d层计算加权均值
在这里插入图片描述
AttentionPool2d

  • CLIP类的代码中,AttentionPool2d里面的query向量是输入特征的均值
  • 在多头注意力模块MHA中,输入query大小为[1,N,C],key=value大小为[(HW+1),N,C]
  • MHA模块的输出为和输入query是一样的大小为[1,N,C],最终返回的为[N,C]
# AttentionPool2d
def forward(self, x):
    x = x.flatten(start_dim=2).permute(2, 0, 1)  # [N,C,H,W] -> [(HW),N,C]
    x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # [(HW+1),N,C]
    x = x + self.positional_embedding[:, None, :].to(x.dtype)  # [(HW+1),N,C]
    x, _ = F.multi_head_attention_forward(query=x[:1], key=x, value=x,out_proj_weight=self.c_proj.weight)
    return x.squeeze(0)

CLIP 模型代码结构——inference

推理过程

  • 如果要做分类任务,先需要写一个prompt,label是分类任务的标签,例如cifar-100分类,需要写This is a photo of a cat / This is a photo of a dog …一百个prompt作为文本描述
  • 然后@运算符进行矩阵乘法,计算要推理的图和文本的相似度
  • Label可以是训练集中没有的,实现zero-shot
text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()
with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)

CLIP 模型代码结构——train

  • 与推理过程类似,计算图文向量的相似度
  • 训练过程中恰好相似度矩阵的对角线就是图文匹配的label
image_features = model.encode_image(image_input)
text_features = model.encode_text(text_tokens)

image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logit_scale * text_features @ image_features.t()

labels = torch.arange(len(logits_per_image)).to(logits_per_image.device)
image_loss = F.cross_entropy(logits_per_image, labels)
text_loss  = F.cross_entropy(logits_per_text, labels)
loss = (image_loss + text_loss) / 2

下一篇

下一篇应该可以看看SAM、MiniCPM-V里面压缩图片编码时采用的类似Q-former结构、SigLip或者MiniCPM-V的文本LLM了

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值