图像分类----图片测试
代码流程
使用训练好的图像分类模型,对新图像文件进行预测。
1.引入库
代码如下(示例):
import torch
import numpy as np
import torchvision
import pandas as pd
from PIL import Image
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import transforms
2.预测单张新图片类别
代码如下(示例):
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = torch.load('训练完成的模型.pth')
model = model.eval().to(device)
test_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
img_path = '图片.jpg'
input_img = test_transform(img_pil)
input_img = input_img.unsqueeze(0).to(device)
pred_logits = model(input_img)
n = 10
top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
confs = top_n[0].cpu().detach().numpy().squeeze() # 解析出置信度
draw = ImageDraw.Draw(img_pil)
for i in range(n):
class_name = idx_to_labels[pred_ids[i]] # 获取类别名称
confidence = confs[i] * 100 # 获取置信度
text = '{:<15} {:>.4f}'.format(class_name, confidence)
print(text)
# 文字坐标,中文字符串,字体,rgba颜色
draw.text((50, 100 + 50 * i), text, font=font, fill=(255, 0, 0, 1))