Pytorch搭建AlexNet 预测实现

1.导包

import torch
import matplotlib.pyplot as plt
import json
from model import AlexNet
from PIL import Image
from torchvision import transforms

2.数据预处理

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),  # 将图片重新裁剪
     transforms.ToTensor(),  # 转化为tensor
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # 标准化数据

 3.加载测试图片

# load image
img = Image.open("1.jpeg")  # 网上随便下载,放到好找的路径下
plt.imshow(img)   # 直接载入图像
img = data_transform(img)  在预处理过程中吧channel提到前面
img = torch.unsqueeze(img, dim=0)  # 添加batch维度

4.读取分类文件

# read class_indent
try:
    # 读取保存在json文件中索引对应的类别名称
    json_file = open('./class_indices,json', 'r')
    class_indict = json.load(json_file)  # 将json文件解码成字典格式
except Exception as e:
    print(e)
    exit(-1)

5.初始化网络

output = torch.squeeze(model(img)):先将图片通过正向传播得到输出,再把输出的batch压缩

predict = torch.softmax(output, dim=0):通过softmax得到一个概率分布

predict_cla = torch.argmax(predict).numpy():找到概率最大处所对应的索引值

print将类别名称和预测概率输出

# create model
model = AlexNet(num_classes=5)
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))  # 载入网络模型
model.eval()  # 关闭dropout
with torch.no_grad():
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

 6.预测结果

容易把玫瑰识别成郁金香,把蒲公英识别成向日葵,郁金香,向日葵,小雏菊可以很好的识别出来,模型的准确率还是有点低。大家自己尝试测试一下吧哈哈。

 PyTorch搭建AlexNet网络合集:
PyTorch搭建AlexNet网络模型-CSDN博客

PyTorch搭建AlexNet训练集-CSDN博客

Pytorch搭建AlexNet 预测实现-CSDN博客

  • 8
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值