ResNet残差网络Pytorch实现——对花的种类进行单数据预测

ResNet残差网络Pytorch实现——对花的种类进行单数据预测


上一篇:【对花的种类进行训练】 ✌✌✌✌ 【目录】 ✌✌✌✌ 下一篇:【对花的种类进行批量数据预测】


大学生一枚,最近在学习神经网络,写这篇文章只是记录自己的学习历程,本文参考了Github上fengdu78老师的文章进行学习


✌ 使用ResNet进行对花的种类进行单数据预测

import os
import json

import torch
from torchvision import transforms

from PIL import Image

# 加载运算设备
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# 数据处理
data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# 预测图片的路径
img_path='13290033_ebd7c7abba_n.jpg'

# 加载图片
img=Image.open(img_path)

# 将图片进行处理,返回的是tensor
img=data_transform(img)

# 将数据升维,图片是三维,而训练时是4维,因为训练时第一个维度为每个批次的训练数据大小
# 本次预测一个图片,就要升维,变成1,表明该批次图片有1个
img=torch.unsqueeze(img,dim=0)

# 读取预测结果和真实分类的映射
json_path='./class_indices.json'

json_file=open(json_path,'r')

# 加载成字典 0:'A',1:'B',2:'C'
class_indict=json.load(json_file)

# 创建网络
model=resnet34(num_classes=5).to(device)

# 加载模型训练好的参数
weitcht_path='./resNet34.pth'
model.load_state_dict(torch.load(weitcht_path,map_location=device))

# 开启验证模式
model.eval()
# 不需要求导
with torch.no_grad():
    # 每个数据对应输出,(1,5)维度,将其降维,直接是(5,)
    output=torch.squeeze(model(img.to(device))).cpu()
    # 如果是预测,完全可以用touch.max,返回最大值索引,但是下面为了输出预测的概率,就要将其标准化,概率和为1
    # 按道理说应该是dim=1,每行的所有列,但是现在不是二维,所以要用0
    predict=torch.softmax(output,dim=0)
    # 返回最大值索引,dim=0和上面的同理
    predict_cla=torch.argmax(predict,dim=0).item()

print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
print(print_res)
  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

海洋 之心

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值