1 背景介绍
CLIP(Contrastive Language-Image Pre-training)是由OpenAI开发的一种多模态预训练模型,它能够将图像和文本嵌入到同一个向量空间中,使它们能够相互“理解”。
CLIP能够同时处理两种不同模态的信息——图像和文本。传统的AI模型通常只处理单一类型的输入,比如只理解文字(如GPT-3)或只理解图像(如ResNet)。而CLIP可以同时理解这两者,这使它特别适合于需要将文字和图像关联起来的任务。
CLIP的核心思想是通过对比学习(Contrastive Learning)来训练模型。这种方法的核心思想是通过比较成对的图像和文本,来学习它们之间的相似度或差异度。在CLIP模型中,模型会接收一批图像-文本对作为输入,并尝试将匹配的图像和文本向量在共同的语义空间中拉近,而将不匹配的向量推远。
CLIP模型由两个主要的编码器组成:一个图像编码器和一个文本编码器。图像编码器负责将图像转换为特征向量,常用的架构有卷积神经网络(如ResNet)和Vision Transformer(ViT)。文本编码器则负责将文本转换为特征向量,通常采用基于Transformer的结构。
ground truth会通过提示词prompt变成一句话,然后传入文本编辑器,测试准确度更高。例如ImageNet有1000个类,这些类会通过prompt变成1000句话,然后传入文本编码器变成1000个向量,然后再拿它们跟图像做匹配,看图像向量和哪个文本向量更加匹配
2 部分概念理解
交叉熵损失函数
L ( Y , P ) = − 1 N ∑ n = 1 N ∑ c = 1 C Y n c l o g ( P n c ) L(Y,P)=-\frac{1}{N}\sum_{n=1}^{N}\sum_{c=1}^{C}Y_{nc}log(P_{nc}) L(Y,P)=−N1∑n=1N∑c=1CYnclog(Pnc)
Y表示真实分布的概率,P表示预测分布的概率,n表示样本个数,c表示分类个数,Ync表示真实的groudtruth的概率,第n个样本分为第C个类别的概率,Pnc表示预测的概率,第n个样本分为第C个类别的概率
对比损失函数
l o s s i = c r o s s − e n t r o p y − l o s s ( l o g i t s , l a b e l s , a x i s = 0 ) loss_i = cross-entropy-loss(logits,labels,axis = 0) lossi=cross−entropy−loss(logits,labels,axis=0)
l o s s t = c r o s s − e n t r o p y − l o s s ( l o g i t s , l a b e l s , a x i s = 0 ) loss_t = cross-entropy-loss(logits,labels,axis = 0) losst=cross−entropy−loss(logits,labels,axis=0)
l o s s = ( l o s s i + l o s s t ) / 2 loss = (loss_i+loss_t)/2 loss=(lossi+losst)/2
一个是一张图片对每个样本做交叉熵损失,也就是对相似度矩阵的每一行做交叉熵损失;
一个是一个文本对每个图片做交叉熵损失,也就是矩阵的每一列做交叉熵损失;
把两个损失分量取和然后除2,就得到了总的损失。
vit划分patch原理
vit论文做法为将给定的一堆图片按照给定的大小分成一堆Patches。本文将输入的图片尺寸为(224×224)按照16×16大小的Patch进行划分。其中(224×224)/(16×16)=196,因此我们会得到196个patches。到这里我们可以知道每一个Patches数据的shape为[16, 16, 3]。为了满足Transformer的需求,在这里,对每个Patch进行投影变化,映射到一维向量中。即每一个Patch完成如下转化:[16, 16, 3]->[768],那么这样一来,就将原始的[224, 224, 3]转化为[196, 768]。
cls token原理
在输入Transformer Encoder之前,值得注意的是需要加上[class] token。在原论文中,作者的意思是参考BERT,在上述得到的一堆tokens中插入一个专门用于分类操作的[class] token,这个[class] token是一个可训练的参数,数据格式和其他token保持一致,均为一个向量。以本文为例,其维度大小为[1, 768]。注意的是,这里采取的是Concat操作。即cat cls token [1, 768]与图像pathch [196, 768] -> [197, 768],此时正好变成了二维矩阵。最终将图像patch变成维度是[197, 768],而本文是将cls token放在第一位,后面分类也是通过cls token给出。
3 核心代码思路示例
导入库
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from transformers import BertModel, BertTokenizer
import timm
import numpy as np
选用图像编码器,采用VIT
# 图像编码器 - 使用ViT
class ViT(nn.Module):
def __init__(self, output_dim):
super(ViT, self).__init__()
# 使用来自timm的ViT模型
self.vit = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=output_dim)
def forward(self, x):
return self.vit(x)
选用文本编码器,采用BERT
class TextEncoder(nn.Module):
def __init__(self):
super(TextEncoder, self).__init__()
#下载的模型路径
BERT_LOCAL_PATH = './bert-base-uncased'
self.model = BertModel.from_pretrained(BERT_LOCAL_PATH)
#使用BertTokenizer.from_pretrained方法加载与BERT模型配套的分词器。这个分词器用于将原始文本转换成BERT模型可以理解的格式。加载后的分词器被赋值给当前对象的tokenizer属性。
self.tokenizer = BertTokenizer.from_pretrained(BERT_LOCAL_PATH)
def forward(self, texts):
# 文本通过BERT
#return_tensors='pt'指定返回的是PyTorch张量。padding=True和truncation=True表示如果输入的文本长度不一致,将进行填充或截断,以确保所有文本具有相同的长度,这是BERT模型处理批量数据时的要求。
encoded_input = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
outputs = self.model(**encoded_input)
#[:, 0, :]表示选择每个样本的第一个标记(通常是特殊标记[CLS])的隐藏状态,这个标记的隐藏状态通常用于分类任务。
return outputs.last_hidden_state[:, 0, :]
加载CIFAR10数据集
def load_cifar10_dataset():
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
train_dataset = CIFAR10(root='./cifar10', train=True, download=True, transform=transform)
loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
classes = train_dataset.classes
return loader, classes
构建clip模型
class CLIP(nn.Module):
def __init__(self, image_output_dim, text_output_dim):
super(CLIP, self).__init__()
self.image_encoder = ViT(image_output_dim)
self.text_encoder = TextEncoder()
# 因为图像和文本emb可能维度不同(图像512(假设举例),文本768),所以需要对图像和文本的emb再经过一层以将维度持平
self.W_i = nn.Parameter(torch.randn(image_output_dim, text_output_dim))
self.W_t = nn.Parameter(torch.randn(768, text_output_dim)) # BERT-base的最后隐藏层大小为768
def forward(self, images, texts):
I_f = self.image_encoder(images) # (B,3,224,224) -> (B, 512)
T_f = self.text_encoder(texts) # (B)-> (B, 768)
# 调整维度
I_e = torch.matmul(I_f, self.W_i) # (B, 512)
T_e = torch.matmul(T_f, self.W_t) # (B, 512)
# 计算余弦相似度
logits = torch.matmul(I_e, T_e.T) # (B,B)
return logits
主函数
def main():
# 加载数据集
dataset, classes = load_cifar10_dataset()
clip_model = CLIP(image_output_dim=512, text_output_dim=512)
for images, labels in dataset:
# 获取一个小批量的图像和标签
texts = [classes[label] for label in labels]
logits = clip_model(images, texts) # (B,B)
#我们希望对角线是真实的值,故把位置当做真实标签
labels = torch.arange(logits.shape[0]) # (0,1,2,3)
# 计算损失 loss_i是每一张图像我都要把它判定为正确得文本,而loss_t是每一个文本我都要把它判定为正确得图像
loss_i = torch.nn.CrossEntropyLoss()(logits, labels)
loss_t = torch.nn.CrossEntropyLoss()(logits.T, labels)
loss = (loss_i + loss_t) / 2
# 输出损失
print(loss)
整体函数流程
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from transformers import BertModel, BertTokenizer
import timm
import numpy as np
class ViT(nn.Module):
def __init__(self, output_dim):
super(ViT, self).__init__()
self.vit = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=output_dim)
def forward(self, x):
return self.vit(x)
class TextEncoder(nn.Module):
def __init__(self):
super(TextEncoder, self).__init__()
BERT_LOCAL_PATH = './bert-base-uncased'
self.model = BertModel.from_pretrained(BERT_LOCAL_PATH)
self.tokenizer = BertTokenizer.from_pretrained(BERT_LOCAL_PATH)
def forward(self, texts):
encoded_input = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
outputs = self.model(**encoded_input)
return outputs.last_hidden_state[:, 0, :]
def load_cifar10_dataset():
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
train_dataset = CIFAR10(root='./cifar10', train=True, download=True, transform=transform)
loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
classes = train_dataset.classes
return loader, classes
class CLIP(nn.Module):
def __init__(self, image_output_dim, text_output_dim):
super(CLIP, self).__init__()
self.image_encoder = ViT(image_output_dim)
self.text_encoder = TextEncoder()
self.W_i = nn.Parameter(torch.randn(image_output_dim, text_output_dim))
self.W_t = nn.Parameter(torch.randn(768, text_output_dim))
def forward(self, images, texts):
I_f = self.image_encoder(images)
T_f = self.text_encoder(texts)
I_e = torch.matmul(I_f, self.W_i)
T_e = torch.matmul(T_f, self.W_t)
logits = torch.matmul(I_e, T_e.T)
return logits
def main():
dataset, classes = load_cifar10_dataset()
clip_model = CLIP(image_output_dim=512, text_output_dim=512)
for images, labels in dataset:
texts = [classes[label] for label in labels]
logits = clip_model(images, texts) # (B,B)
labels = torch.arange(logits.shape[0]) # (0,1,2,3)
loss_i = torch.nn.CrossEntropyLoss()(logits, labels)
loss_t = torch.nn.CrossEntropyLoss()(logits.T, labels)
loss = (loss_i + loss_t) / 2
print(loss)
if __name__ == "__main__":
main()
4 损失函数计算示例
为方便理解,我们举一个计算CLIP损失函数的例子。为了简化说明,我们将使用一个小的批次大小(例如B=2),并假设已经有了图像和文本的嵌入向量。
例子设置
- 批次大小(B)= 2
- 图像嵌入向量:I1, I2
- 文本嵌入向量:T1, T2
- 正确的匹配是:I1与T1,I2与T2
步骤1:计算相似度
首先,我们计算图像嵌入向量和文本嵌入向量之间的相似度。在这个例子中,我们使用点积作为相似度度量。
T1 | T2 | |
---|---|---|
I1 | I1·T1 | I1·T2 |
I2 | I2·T1 | I2·T2 |
假设计算得到的相似度矩阵如下:
T1 | T2 | |
---|---|---|
I1 | 3.0 | 0.5 |
I2 | 0.2 | 2.8 |
步骤2:应用softmax函数
接下来,我们对相似度矩阵的每一行和每一列应用softmax函数,以将相似度分数转换为概率分布。
图像到文本的softmax
- 对于I1:softmax([3.0, 0.5])
- 对于I2:softmax([0.2, 2.8])
假设得到的softmax概率为:
T1 | T2 | |
---|---|---|
I1 | 0.95 | 0.05 |
I2 | 0.05 | 0.95 |
文本到图像的softmax
- 对于T1:softmax([3.0, 0.2])
- 对于T2:softmax([0.5, 2.8])
假设得到的softmax概率为:
I1 | I2 | |
---|---|---|
T1 | 0.95 | 0.05 |
T2 | 0.07 | 0.93 |
步骤3:计算交叉熵损失
现在,我们计算图像到文本的损失和文本到图像的损失。
图像到文本的损失
- 对于I1,正确的匹配是T1,所以目标概率是[1, 0]。计算交叉熵损失:
L i 2 t 1 = − l o g ( 0.95 ) L_{i2}t_{1} = -log(0.95) Li2t1=−log(0.95)
- 对于I2,正确的匹配是T2,所以目标概率是[0, 1]。计算交叉熵损失:
L i 2 t 2 = − l o g ( 0.95 ) L_{i2}t_{2} = -log(0.95) Li2t2=−log(0.95)
图像到文本的总损失是这两个损失的平均值:
L i 2 t = ( L i 2 t 1 + L i 2 t 2 ) / 2 L_{i2}t = (L_{i2}t_{1} + L_{i2}t_{2}) / 2 Li2t=(Li2t1+Li2t2)/2
文本到图像的损失
- 对于T1,正确的匹配是I1,所以目标概率是[1, 0]。计算交叉熵损失:
L t 2 i 1 = − l o g ( 0.95 ) L_{t2}i_{1} = -log(0.95) Lt2i1=−log(0.95)
- 对于T2,正确的匹配是I2,所以目标概率是[0, 1]。计算交叉熵损失:
L t 2 i 2 = − l o g ( 0.93 ) L_{t2}i_{2} = -log(0.93) Lt2i2=−log(0.93)
文本到图像的总损失是这两个损失的平均值:
L t 2 i = ( L t 2 i 1 + L t 2 i 2 ) / 2 L_{t2}i = (L_{t2}i_{1} + L_{t2}i_{2}) / 2 Lt2i=(Lt2i1+Lt2i2)/2
步骤4:计算总损失
最后,我们将图像到文本的损失和文本到图像的损失相加,然后取平均值作为最终的总损失:
L c l i p = ( L i 2 t + L t 2 i ) / 2 L_{clip} = (L_{i2}t + L_{t2}i) / 2 Lclip=(Li2t+Lt2i)/2
最终结果
将上述损失值代入公式,即可得到CLIP模型在这个小批次上的总损失。请注意,这里的损失值是基于假设的相似度分数和softmax概率计算得出的,实际训练过程中这些值会根据模型的参数和输入数据动态变化。