CLIP模型使用(ViT+BERT)
这个ViT+BERT只说明最简单的CLIP文本输出格式,它的输出有什么,之后的文章将会引出两个模态的 embeds 从何而来,它是如何从模型的输出特征中提取出来,为此模型实际上又做了什么改变
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
from IPython import embed
from transformers import CLIPTokenizer
import torch
model_name = "/public_bme/data/breast-10-12/CausalFromText/clip-vit-b-16/"
model = CLIPModel.from_pretrained(model_name)
tokenizer = CLIPTokenizer.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)
texts = ["This is a description of the first image.", "New as a word of the good fuck."]
images = [Image.open("CLIP.png"), Image.open("CLIP.png")]
inputs_text = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
inputs = processor(texts=texts, images=images, return_tensors="pt", padding=True)
inputs.update(inputs_text)
model=model.to("cuda")
inputs=inputs.to("cuda")
with torch.no_grad():
outputs = model(**inputs)
但是这个outputs其实是一个字典
In [1]: outputs.keys()
Out[1]: odict_keys(['logits_per_image', 'logits_per_text', 'text_embeds', 'image_embeds', 'text_model_output', 'vision_model_output'])
-
那它们分别都是什么
- 1. (logits_per_image), (logits_per_text)
- 2. (text_embeds), (image_embeds)
- 3. (text_model_output), (vision_model_output)
1. 这里的logits是通过embeds计算而来的相似度,用来做匹配的
如果将image和text其中一个进行转置,就会发现它是相等的
In [6]: outputs.logits_per_image
Out[6]:
tensor([[24.3546, 17.5897],
[24.3546, 17.5897]], device='cuda:0')
In [7]: outputs.logits_per_text
Out[7]:
tensor([[24.3546, 24.3546],
[17.5897, 17.5897]], device='cuda:0')
In [9]: outputs.logits_per_text==outputs.logits_per_image.t()
Out[9]:
tensor([[True, True],
[True, True]], device='cuda:0')
2. 这里的 embeds 是通过 outputs 里面的 last_hidden_state 计算而来,,用来做匹配的
In [14]: outputs.text_embeds.shape
Out[14]: torch.Size([2, 512])
In [26]: outputs.image_embeds.shape
Out[26]: torch.Size([2, 512])
3. 这里的 last_hidden_state 是两个模态 model 的输出
视觉编码器的输出是这样的,你可以很容易通过这样的理解,迁移使用到resnet等其他网络
In [20]: outputs.vision_model_output.last_hidden_state.shape
Out[20]: torch.Size([2, 197, 768])
In [21]: outputs.text_model_output.last_hidden_state.shape
Out[21]: torch.Size([2, 12, 512])
模型的训练
我们看到CLIP的输出的logits是一个与batch相关的对角矩阵,这个矩阵实际上就是下面图片的逻辑值
通过阅读论文,我们注意到了CLIP是如何计算损失的,原文如下:
实际上的python程序是这样的
logits_per_image, logits_per_text = model(inputs)
labels = torch.arange(batch_size).to("cuda")
loss_i = F.cross_entropy(logits_per_image, labels)
loss_t = F.cross_entropy(logits_per_text, labels)
loss = (loss_i + loss_t) / 2