利用CLIP的API来识别验证CLIP在CIFAR10/100的零样本识别效果。
Learning Transferable Visual Models From Natural Language Supervision
论文链接:https://arxiv.org/pdf/2103.00020
1、简述
OpenAI 2021年发布的视觉-语言自监督大模型,为后续各种视觉、语言任务的大模型开发奠定基础。
论文摘要:
最新技术的计算机视觉系统被训练来预测一组固定的预先确定的物体类别。这种受限的监督形式限制了它们的普适性和可用性,因为需要额外的标记数据来指定任何其他视觉概念。直接从关于图像的原始文本中学习是一个有前途的替代方法,它利用了一个更广泛的监督来源。文中展示了,预测哪个标题适合哪个图像这个简单的预训练任务是一种有效且可扩展的方式,可以从互联网收集的4亿(图像,文本)对数据集上从零开始学习最先进的图像表示。预训练后,自然语言用于引用学习的视觉概念(或描述新的概念),从而实现将模型零迁移到下游任务。
通过在超过30个不同的现有计算机视觉数据集上进行基准测试来研究这种方法的性能,覆盖了OCR、视频中的动作识别、地理定位以及许多类型的细粒度对象分类等任务。该模型在大多数任务上具备非平凡的迁移能力,并且通常与完全监督的基准线相竞争,而无需进行任何特定数据集的训练。例如,在ImageNet的零迁移上与原始的ResNet-50的准确率相匹配,而无需使用其训练过的1.28百万个训练样本之一。
2、使用官方CLIP进行CIFAR10/100零样本识别
首先,安装pytorch,clip和其他依赖包,具体安装方式可以在终端输入命令,如下:
conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
pip install ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git
说明:如果已经有安装过pytorch的环境就只需要安装clip和依赖的其他包就可以了。
官方提供了['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
等多个主干网络,可以根据自身电脑配置选择一个模型,这里以"RN50"和”ViT-B/32“进行示例。
接下来,先上完整的”ViT-B/32“调用进行CIFAR100零样本识别任务,具体代码如下:
import torch
import torchvision
import clip
from tqdm import tqdm
# set random seed
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def accuracy(output, target, topk=(1, )):
pred = output.topk(max(topk), 1, True, True)[1].t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
def zeroshot_classifier(classes_names, templates, device):
with torch.no_grad():
zeroshot_weights = []
for classes_name in tqdm(classes_names):
texts = [template.format(classes_name) for template in templates]
texts = clip.tokenize(texts).to(device)
class_embeddings = model.encode_text(texts)
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) # normalize
class_embeddings = class_embeddings.mean(dim=0)
class_embeddings /= class_embeddings.norm()
zeroshot_weights.append(class_embeddings)
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
return zeroshot_weights
print(clip.available_models())
# set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"{device}")
# set clip model
model, preprocess = clip.load("ViT-B/32", download_root="./ckpts", device=device)
# model, preprocess = clip.load("RN50", download_root="./ckpts", device=device)
# set dataset and dataloader
# CIFAR10: torchvision.datasets.CIFAR10
# CIFAR100: torchvision.datasets.CIFAR100
train_dataset = torchvision.datasets.CIFAR100(root='./data/',
train=True,
download=True,
transform=preprocess)
test_dataset = torchvision.datasets.CIFAR100(root='./data/',
train=False,
download=True,
transform=preprocess)
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=32,
shuffle=True,
num_workers=0)
test_loader = torch.utils.data.DataLoader(test_dataset,
batch_size=32,
shuffle=False,
num_workers=0)
classes = train_dataset.classes
print(classes)
# set language template
cifar_template = ["a photo of a {}."] # 1 这里所有图片都用一个template
zeroshot_weights = zeroshot_classifier(classes, cifar_template, device)
# test
with torch.no_grad():
top1, top5, n = 0., 0., 0.
for i, (images, target) in enumerate(tqdm(test_loader)):
images = images.to(device)
target = target.to(device)
# predict
image_feat = model.encode_image(images)
image_feat /= image_feat.norm(dim=-1, keepdim=True)
logits = 100. * image_feat @ zeroshot_weights
# measure accuracy
acc1, acc5 = accuracy(logits, target, topk=(1, 5))
top1 += acc1
top5 += acc5
n += images.size(0)
# calculate and plot results
top1 = (top1 / n) * 100
top5 = (top5 / n) * 100
print(f"Top1 accuracy: {top1:.2f}, Top5 accuracy: {top5:.2f}")
3、对template的消融实验
对template的质量和数量进行消融实验,主要将template从一个增加到10个,28个,同时改变template的质量(表达)来测试模型识别性能。
cifar_template = [
'a photo of a {}.',
'a photo of the hard to see {}.',
'a low resolution photo of the {}.',
'a rendering of a {}.',
'a cropped photo of the {}.',
'a photo of a hard to see {}.',
'a photo of a dirty {}.',
'a dark photo of the {}.',
'a drawing of a {}.',
'a photo of my {}.',
'the plastic {}.',
'a close-up photo of a {}.',
'a painting of the {}.',
'a painting of a {}.',
'a pixelated photo of the {}.',
'a cropped photo of a {}.',
'a photo of one {}.',
'a photo of a {}.',
'a low resolution photo of a {}.',
'a blurry photo of a {}.',
'art of a {}.',
'a sketch of the {}.',
'a pixelated photo of a {}.',
'a jpeg corrupted photo of the {}.',
'a photo of the nice {}.',
'art of the {}.',
'a drawing of the {}.',
'a photo of a small {}.',
]
CIFAR100实验结果如下:
Model | #template | Top-1 | Top-5 |
---|---|---|---|
ViT-B/32 | 1 | 62.24 | 86.99 |
RN50 | 1 | 40.59 | 72.86 |
ViT-B/32 | 10 | 63.23/64.33* | 87.28/88.64* |
RN50 | 10 | 40.30/40.77* | 71.23/72.22* |
ViT-B/32 | 24 | 64.16 | 88.19 |
RN50 | 24 | 41.22 | 72.62 |
ViT-B/32 | 28 | 64.18 | 88.26 |
RN50 | 28 | 41.19 | 72.63 |
说明:”*“代表使用相同数量但与前一个结果不同的template。
从实验结果可以得出以下结论:
-
(1)不同类型的网络结构,transformer比cnn模型性能更好,这主要是因为transformer的自注意力和交叉注意力机制,提升了网络对全局、局部特征的表达能力。
-
(2)不同数量的template,数量较大的template普遍优于数量为1的template,同时template数量并不是越大越好,还跟template的质量和表达有关,这说明template的质量和数量在一定程度上对模型性能有较大的影响。
-
(3)当template数量逐渐增加时,模型性能更好,因为样本对应的文字表达更加丰富,特征间的判别性越明显,但当template达到一定数量时,模型性能达到最高值,并趋于稳定,这主要跟模型的表达能力有关。