当网络训练后,就可以导入模型进行测试。
def single_img_predict(model_path,img_path):
# 加载网络
model = torch.load(model_path, map_location=lambda storage, loc: storage)
# print(model)
transform = transforms.Compose([
# 这里只对其中的一个通道进行归一化的操作
transforms.Resize((224,224)),
transforms.ToTensor(),
transforms.Normalize([0.5529,0.55,0.5529], [0.5529,0.50,0.50])])
"""If the unread image is RGBA four-channel the A channel is a transparent channel, and the channel value is temporarily
not used for deep learning model training,so use convert (‘RGB’) for channel conversion"""
img = Image.open(img_path)
RGB_img = img.convert('RGB')
img_torch = transform(RGB_img)
img_torch = img_torch.view(-1, 3, 224, 224)
# 开始预测
outputs = model(img_torch)
_, predicted = torch.max(outputs.data, 1) #即为索引
print(predicted)
print('正在全力预测中.......')
print("上传的图片的预测结果是:" +names[predicted[0].numpy()])
cv2.namedWindow('selected_jpg', cv2.WINDOW_FREERATIO)
cv2.imshow("selected_jpg", cv2.imread(img_path,1)) #imshow图片显示函数 #imread图片读入函数
cv2.waitKey(0) # 键盘绑定函数,参数一般写为0,这样会无限等待键盘的输入,没有返回值。
cv2.destroyAllWindows() #是一个可以轻易删除任何我们建立的窗口
model_path='C:\\Users\\1\\net.pkl'
img_path='C:\\Users\\1\\yilaguangray.jpg'
single_img_predict(model_path,img_path)
这里用到了opencv打开一个窗口来读取图片。首先先将图片进行预处理,转为RGB彩图,再利用导入训练好的模型进行预测,这里print(predicted)就是最大值对应的索引,通过列表即可找到对应的标签,opencv的作用已经在代码中注释啦。结果如下图: