【多模态】CLIP模型详解

目录

🍋🍋1.网络整体结构

🍑🍑1.1预训练阶段​编辑

🍑🍑1.2提取预测类别文本

🍑🍑1.3推理预测

🍇🍇2.代码实现

🍏🍏2.1推理代码

🍏🍏2.2zero-shot预测CIFAR100数据集

整理不易,欢迎一键三连!!!


        最近在看多模态的知识,发现怎么都绕不过CLIP模型,所以一探究竟。CLIP 全称 Contrastive Language-Image Pre-training,具有十分强悍的迁移学习能力,为了佐证这个能力,在超过 30 多个视觉数据上进行测试,涵盖面十分广泛,包括 OCR、视频动作检测、坐标定位和许多细分类任务,在所有的结果中最炸裂的一条就是在 ImageNet 上的结果,CLIP 在不使用任意一张 ImageNet 图片训练的情况下,直接 Zero-Shot 推理,就能获得跟之前有监督训练的 ResNet-50 同样优秀的结果,Clip模型还可以用于图像生成、图像检索、视觉问答、视频理解等任务,因此对多模态学习中具体深刻的历史意义。

 这效果还有谁!
🐾🐾论文地址:paper
🐾🐾代码地址:code

🍋🍋1.网络整体结构

如图所示,CLIP共有3个阶段        

  • Contrastive pre-training:预训练阶段,使用图片 - 文本对进行对比学习训练;
  • Create dataset classifier from label text:提取预测类别文本特征;
  • Use for zero-shot predictiion:进行 Zero-Shot 推理预测;

模型伪代码如下:

🍑🍑1.1预训练阶段

        在预训练阶段,对比学习中正样本对和负样本的定义为能够配对的图片-文本对,和不能匹配的图片-文本对。具体来说,先分别对图像和文本提特征,这时图像对应生成 I1、I2 ... In 的特征向量,文本对应生成 T1、T2 ... Tn 的特征向量,然后中间对角线为正样本,其余均为负样本。这样的话就形成了 n 个正样本,n^2 - n 个负样本,有了正负样本,模型就可以通过对比学习的方式训练起来了,完全不需要手工的标注。当然,自监督的训练需要大量的数据,OPEN AI 的这个训练数据量大约在 4亿个的数量级,数据来源均来自于互联网。

🍑🍑1.2提取预测类别文本

        由于CLIP 预训练时候的文本端输出输入的是个句子,但原始的类别都是句子,因此首先需要对文本类别进行一些单词转句子的处理,如法如下:使用 A photo of a {object}. 的提示模板 (prompt template) 进行构造,比如对于 dog,就构造成 A photo of a dog.,然后再送入 Text Encoder 进行特征提取,这样就会得到一个文本的特征向量。

🍑🍑1.3推理预测

        模型推理比较简单,只需要将输入图片传给ImagesEncoder模块,就会生成一个一维的图片特征向量,然后拿这个图片特征和 第二阶段生成的文本特征做余弦相似度对比,最相似的即为我们想要的那个结果,比如这里应该会得到 A photo of a dog. 


🍇🍇2.代码实现

🍏🍏2.1推理代码

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)    #图像预处理
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)     #文本token化

with torch.no_grad():
    image_features = model.encode_image(image)      #图像编码
    text_features = model.encode_text(text)         #文本编码
    
    logits_per_image, logits_per_text = model(image, text)  #模型预测
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()    #GPUtensor转CPUnumpy

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

🍏🍏2.2zero-shot预测CIFAR100数据集

import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)#下载数据集

# Prepare the inputs
image, class_id = cifar100[3637]  #取其中一张图
image_input = preprocess(image).unsqueeze(0).to(device)   #图像预处理
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in  cifar100.classes]).to(device)  #文本预处理

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)   #图像编码器编码
    text_features = model.encode_text(text_inputs)     文本编码器编码
 
# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)  #图像特征归一化
text_features /= text_features.norm(dim=-1, keepdim=True)    #文本特征归一化
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)   #图像-文本相似度计算
values, indices = similarity[0].topk(5)  #打印相似度top5

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

输出结果如下:

Top predictions:

           snake: 65.31%
          turtle: 12.29%
    sweet_pepper: 3.83%
          lizard: 1.88%
       crocodile: 1.75%

整理不易,欢迎一键三连!!!

送你们一条美丽的--分割线--


🌷🌷🍀🍀🌾🌾🍓🍓🍂🍂🙋🙋🐸🐸🙋🙋💖💖🍌🍌🔔🔔🍉🍉🍭🍭🍋🍋🍇🍇🏆🏆📸📸⛵⛵⭐⭐🍎🍎👍👍🌷🌷

  • 5
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

zy_destiny

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值