CLIP
This article is a walkthrough of the Colab notebook from the CLIP model's GitHub repository, covering CLIP's main ideas.
CLIP's available pretrained models
clip.available_models()
['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
Data preprocessing
Images: resize the input image, center-crop it, convert it to a tensor, and normalize it
Compose(
Resize(size=224, interpolation=bicubic, max_size=None, antialias=warn)
CenterCrop(size=(224, 224))
<function _convert_image_to_rgb at 0x7c4f5656cc10>
ToTensor()
Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)
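The `Compose` pipeline printed above can be sketched in plain NumPy. This is an illustrative sketch, not CLIP's actual code: it reproduces only the `ToTensor`-style scaling to [0, 1] and the per-channel normalization with CLIP's mean/std, and omits `Resize` and `CenterCrop`.

```python
import numpy as np

# CLIP's per-channel normalization constants, reshaped for (C, H, W) broadcasting
MEAN = np.array([0.48145466, 0.4578275, 0.40821073]).reshape(3, 1, 1)
STD = np.array([0.26862954, 0.26130258, 0.27577711]).reshape(3, 1, 1)

def normalize(image_uint8):
    """image_uint8: (3, 224, 224) uint8 array -> normalized float array."""
    x = image_uint8.astype(np.float32) / 255.0  # ToTensor: scale to [0, 1]
    return (x - MEAN) / STD                     # Normalize per channel

img = np.full((3, 224, 224), 128, dtype=np.uint8)  # dummy gray image
out = normalize(img)
print(out.shape)  # (3, 224, 224)
```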
Text: tokenized to a fixed length of 77
clip.tokenize("Hello World!")
tensor([[49406, 3306, 1002, 256, 49407, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)
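The padding scheme behind that output can be sketched as follows. This is a simplified, hypothetical reconstruction (not `clip.tokenize` itself): a start-of-text token (49406) and end-of-text token (49407) wrap the BPE token ids, and the remainder is zero-padded to the context length of 77. The ids 3306, 1002, 256 are taken from the printed output above.

```python
CONTEXT_LENGTH = 77
BOS, EOS = 49406, 49407  # start-of-text and end-of-text token ids

def pad_to_context(token_ids, context_length=CONTEXT_LENGTH):
    """Wrap token ids with BOS/EOS and zero-pad to the context length."""
    seq = [BOS] + list(token_ids) + [EOS]
    assert len(seq) <= context_length, "input too long"
    return seq + [0] * (context_length - len(seq))

row = pad_to_context([3306, 1002, 256])  # ids for "hello world !" from above
print(len(row), row[:6])  # 77 [49406, 3306, 1002, 256, 49407, 0]
```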
Input dataset
Image-text pairs: each image has a corresponding text description
Network architecture
Encoding images and text
image_input = torch.tensor(np.stack(images)).cuda() # np.stack turns the list of images into one array
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda() # form a full sentence for each description
image_input: [8, 3, 224, 224]
text_tokens: [8, 77]
Encoding with the pretrained encoders
with torch.no_grad():
image_features = model.encode_image(image_input).float()
text_features = model.encode_text(text_tokens).float()
Both feature tensors have shape [8, 512]
Computing similarity
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
Normalization: convert each row into a unit vector
Multiply the two matrices to compute the similarity: [8, 512] @ [512, 8] -> [8, 8]
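The normalization-plus-matmul step can be reproduced in plain NumPy with random stand-ins for the encoder outputs. Because each row is L2-normalized first, every entry of the product is a cosine similarity and therefore lies in [-1, 1].

```python
import numpy as np

# Random stand-ins for the [8, 512] encoder outputs (not real CLIP features)
rng = np.random.default_rng(0)
text_features = rng.standard_normal((8, 512))
image_features = rng.standard_normal((8, 512))

# Normalize each row to unit length
text_features /= np.linalg.norm(text_features, axis=-1, keepdims=True)
image_features /= np.linalg.norm(image_features, axis=-1, keepdims=True)

# Pairwise cosine similarities: [8, 512] @ [512, 8] -> [8, 8]
similarity = text_features @ image_features.T
print(similarity.shape)  # (8, 8)
```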
CLIP zero-shot image classification
Dataset: CIFAR-100
100-way classification; text_tokens: [100, 77]
text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()
Image classification
Compute text_features for the 100 classes, then compute their similarity with the image_features from before.
Here, 100 is a scaling factor. If we drop it, the resulting similarity scores are close to one another and fail to separate the positive sample from the negatives. In the CLIP source code this scale is a learned network parameter (logit_scale).
with torch.no_grad():
text_features = model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) # 100 is a scaling factor
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1) # take the top 5
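The effect of the scaling factor can be seen with toy numbers (these are made-up cosine similarities, not real CLIP scores). Cosine similarities live in [-1, 1], so without scaling the softmax output is nearly uniform; multiplying by 100 sharpens the distribution onto the best-matching class.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()

sims = np.array([0.30, 0.25, 0.22])  # toy cosine similarities for 3 classes

p_unscaled = softmax(sims)          # nearly uniform: scores barely separated
p_scaled = softmax(100.0 * sims)    # sharply peaked on the best class

print(p_unscaled)  # all close to 1/3
print(p_scaled)    # dominated by the first class
```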