import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
# Preprocessing for the ImageNet-pretrained VGG19: resize to a fixed 224x224
# input, convert the PIL image to a float tensor in [0, 1], then normalize each
# channel. torchvision's pretrained classification models expect the ImageNet
# mean/std below; the previous (0.5, 0.5, 0.5) / (0.5, 0.5, 0.5) values do not
# match the statistics the weights were trained with.
trans = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
# Load the test image. Converting to RGB guards against .bmp files stored as
# grayscale or palette images, which the 3-channel Normalize in `trans` cannot
# handle; the `with` block closes the underlying file once the data is loaded.
with Image.open('testsets/set5/butterfly.bmp') as raw_img:
    img = raw_img.convert('RGB')
img = trans(img)
# Add a batch dimension: (C, H, W) -> (1, C, H, W).
img = torch.unsqueeze(img, dim=0)
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=...`; kept here for compatibility with older versions.
model = torchvision.models.vgg19(pretrained=True)
model.eval()  # switch to evaluation mode for inference
with torch.no_grad():
    # Drop the batch dimension so `output` is a 1-D vector of class scores.
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)  # probability distribution over classes
    predict_cla = torch.argmax(predict).numpy()  # index of the most probable class
def get_max(n, pre, label_path='imagenet1000_clsid_to_human.txt'):
    """Return the names of the ``n`` highest-scoring classes, best first.

    Args:
        n: How many classes to return. If it exceeds the number of scores,
            every class is returned.
        pre: 1-D scores/probabilities indexed by class id (a CPU torch tensor,
            a NumPy array, or a sequence of numbers).
        label_path: Path to the class-id -> name mapping. It is expected to be
            a Python dict literal such as ``{0: 'tench, Tinca tinca', ...}``,
            the format of ``imagenet1000_clsid_to_human.txt``.

    Returns:
        A list of up to ``n`` label strings in descending score order. The
        index of each selected class is printed as a side effect.
    """
    import ast

    # Parse the whole mapping as a dict literal rather than splitting each line
    # on single quotes: labels that contain an apostrophe (e.g.
    # "jack-o'-lantern") are written with double quotes and were mis-parsed.
    # NOTE(review): assumes the label file is the dict-literal format above.
    with open(label_path, 'r', encoding='utf-8') as f:
        labels = ast.literal_eval(f.read())

    scores = np.asarray(pre)
    # Indices sorted by descending score; the stable sort keeps ties in index order.
    order = np.argsort(-scores, kind='stable')
    names = []
    for idx in order[:n]:  # slicing clamps n to the number of classes
        idx = int(idx)
        print(idx)
        names.append(labels[idx])
    return names
# Show the names of the five most probable classes for the test image.
print(get_max(n=5, pre=predict))
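# 结果 (Result): running the script prints the index of each of the top-5 classes, then the list of their names.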
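# 文件 imagenet1000_clsid_to_human.txt 下载 (Download the label file imagenet1000_clsid_to_human.txt, which maps class ids to human-readable names; the download link is missing from this copy.)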