参考链接:
为了能够结合框标注:可以参考我之前的博客,并将其融合到上述的代码中。
一、CLIP 模型预测
可以参考自己的模型,将几个想要显示的样例,用一个前馈模型,得到最终预测的 topk 答案存储。下一步就是将结果进行可视化展示。
from torchvision.datasets import CIFAR100
cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)
text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()
with torch.no_grad():
text_features = model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)
二、CLIP 模型预测进行可视化图表展示
既然为了画图,就直接用原数据。上面 公式是利用Clip 模型进行预测tokp3的信息,然后绘制图
图片:
left1.png bed 0.14 pickup_truck: 0.7 bridge: 0.6
left2.png racket 0.9 skyscraper: 0.05 bus: 0.05
right1.png cattle 0.4 cammel: 0.15 plain: 0.15
right2.png boy 0.25 man: 0.25 plain: 0.05
1. 参照数据和图,画默认图示
下面不用预测直接用他的图片和显示的数据构造一个显示的程序。下面是网站的程序,我们看看怎么改成上面的数据进行静态展示。
plt.figure(figsize=(16, 16))
for i, image in enumerate(original_images):
plt.subplot(4, 4, 2 * i + 1)
plt.imshow(image)
plt.axis("off")
plt.subplot(4, 4, 2 * i + 2)
y = np.arange(top_probs.shape[-1])
plt.grid()
plt.barh(y, top_probs[i])
plt.gca().invert_yaxis()
plt.gca().set_axisbelow(True)
plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
plt.xlabel("probability")
plt.subplots_adjust(wspace=0.5)
plt.show()
改成:
plt.figure(figsize=(16, 16))
#bed 0.14 pickup_truck: 0.7 bridge: 0.6
#racket 0.9 skyscraper: 0.05 bus: 0.05
#cattle 0.4 cammel: 0.15 plain: 0.15
#boy 0.25 man: 0.25 plain: 0.05
original_images = []
img_name=["left1.png","right1.png","left2.png","right2.png"]
text_name=[[],[],[],[]]
for i in range(4):
image = Image.open(img_name[i]).convert("RGB")
original_images.append(image)
label_name = ['bed','pickup_truck','bridge','racket','skyscraper','bus','cattle','cammel','plain','boy','man']
top_probs=[[0.14,0.7,0.6],[0.9,0.05,0.05],[0.4,0.15,0.15],[0.25,0.25,0.05]]
top_labels= [[0,1,2],[3,4,5],[6,7,8],[9,10,8]]
for i, image in enumerate(original_images):
plt.subplot(2, 4, 2 * i + 1)
plt.imshow(image)
plt.axis("off")
plt.subplot(2, 4, 2 * i + 2)
y = np.arange(3)
plt.grid()
plt.barh(y, top_probs[i])
plt.gca().invert_yaxis()
plt.gca().set_axisbelow(True)
plt.yticks(y, [label_name[index] for index in top_labels[i]])
plt.xlabel("probability")
plt.subplots_adjust(wspace=0.5)
plt.show()
2. 调整到适合大小
将fig 大小由(16,16)变为(16,8)就正常显示 。fig 这里指的是整个大图的大小,大图由多个子图组成。fig(宽,长)
3. 改改柱状图的颜色
4.私人定制指定柱子的颜色