目录
最近在看多模态的知识,发现怎么都绕不过CLIP模型,所以一探究竟。CLIP 全称 Contrastive Language-Image Pre-training,具有十分强悍的迁移学习能力,为了佐证这个能力,在超过 30 多个视觉数据上进行测试,涵盖面十分广泛,包括 OCR、视频动作检测、坐标定位和许多细分类任务,在所有的结果中最炸裂的一条就是在 ImageNet 上的结果,CLIP 在不使用任意一张 ImageNet 图片训练的情况下,直接 Zero-Shot 推理,就能获得跟之前有监督训练的 ResNet-50 同样优秀的结果,Clip模型还可以用于图像生成、图像检索、视觉问答、视频理解等任务,因此对多模态学习中具体深刻的历史意义。
这效果还有谁!
🐾🐾论文地址:paper
🐾🐾代码地址:code
🍋🍋1.网络整体结构
如图所示,CLIP共有3个阶段
- Contrastive pre-training:预训练阶段,使用图片 - 文本对进行对比学习训练;
- Create dataset classifier from label text:提取预测类别文本特征;
- Use for zero-shot predictiion:进行 Zero-Shot 推理预测;
模型伪代码如下:
🍑🍑1.1预训练阶段
在预训练阶段,对比学习中正样本对和负样本的定义为能够配对的图片-文本对,和不能匹配的图片-文本对。具体来说,先分别对图像和文本提特征,这时图像对应生成 I1、I2 ... In 的特征向量,文本对应生成 T1、T2 ... Tn 的特征向量,然后中间对角线为正样本,其余均为负样本。这样的话就形成了 n 个正样本,n^2 - n 个负样本,有了正负样本,模型就可以通过对比学习的方式训练起来了,完全不需要手工的标注。当然,自监督的训练需要大量的数据,OPEN AI 的这个训练数据量大约在 4亿个的数量级,数据来源均来自于互联网。
🍑🍑1.2提取预测类别文本
由于CLIP 预训练时候的文本端输出输入的是个句子,但原始的类别都是句子,因此首先需要对文本类别进行一些单词转句子的处理,如法如下:使用 A photo of a {object}.
的提示模板 (prompt template) 进行构造,比如对于 dog,就构造成 A photo of a dog.
,然后再送入 Text Encoder 进行特征提取,这样就会得到一个文本的特征向量。
🍑🍑1.3推理预测
模型推理比较简单,只需要将输入图片传给ImagesEncoder模块,就会生成一个一维的图片特征向量,然后拿这个图片特征和 第二阶段生成的文本特征做余弦相似度对比,最相似的即为我们想要的那个结果,比如这里应该会得到 A photo of a dog.
🍇🍇2.代码实现
🍏🍏2.1推理代码
import torch
import clip
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device) #图像预处理
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device) #文本token化
with torch.no_grad():
image_features = model.encode_image(image) #图像编码
text_features = model.encode_text(text) #文本编码
logits_per_image, logits_per_text = model(image, text) #模型预测
probs = logits_per_image.softmax(dim=-1).cpu().numpy() #GPUtensor转CPUnumpy
print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]]
🍏🍏2.2zero-shot预测CIFAR100数据集
import os
import clip
import torch
from torchvision.datasets import CIFAR100
# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)
# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)#下载数据集
# Prepare the inputs
image, class_id = cifar100[3637] #取其中一张图
image_input = preprocess(image).unsqueeze(0).to(device) #图像预处理
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device) #文本预处理
# Calculate features
with torch.no_grad():
image_features = model.encode_image(image_input) #图像编码器编码
text_features = model.encode_text(text_inputs) 文本编码器编码
# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True) #图像特征归一化
text_features /= text_features.norm(dim=-1, keepdim=True) #文本特征归一化
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1) #图像-文本相似度计算
values, indices = similarity[0].topk(5) #打印相似度top5
# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
输出结果如下:
Top predictions:
snake: 65.31%
turtle: 12.29%
sweet_pepper: 3.83%
lizard: 1.88%
crocodile: 1.75%
整理不易,欢迎一键三连!!!
送你们一条美丽的--分割线--
🌷🌷🍀🍀🌾🌾🍓🍓🍂🍂🙋🙋🐸🐸🙋🙋💖💖🍌🍌🔔🔔🍉🍉🍭🍭🍋🍋🍇🍇🏆🏆📸📸⛵⛵⭐⭐🍎🍎👍👍🌷🌷