用pytorch构建Alexnet模型(predict模块)(个人笔记)

将原代码修改了一下,可以将所需要预测的图片放在指定文件夹内(可放多个),代码会依次预测该图片属于哪个类,并将其保存在TXT文件中(run栏不显示预测结果不要感到奇怪,打开当前目录下生成的txt文件即可看到结果)。

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(    # 定义图片预处理函数,用来对载入图片进行预处理操作
        [transforms.Resize((224, 224)),   # 缩放到224*224
         transforms.ToTensor(),   # 转化为一个tensor
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])    # 标准化处理
    txtfilepath = "F:\PyTorch\picture"   # 原始txt文件所存文件夹,文件夹可以有一个或多个txt文件
    total_txt = os.listdir(txtfilepath)   # 返回指定的文件夹包含的文件或文件夹的名字的列表
    num = len(total_txt)
    list = range(num)  # 创建从0到num的整数列表

    for i in list:
        name = total_txt[i]
        # load image
        img = Image.open(txtfilepath+"/"+name, 'r') #读取文件

        plt.imshow(img)  # 展示输入的图片
        # [N, C, H, W]
        img = data_transform(img)   # 调用预处理函数,对载入读片进行预处理
        # expand batch dimension
        img = torch.unsqueeze(img, dim=0)  # 预处理之后扩充一个维度(batch维度),这与Alexnet输入有关(具体见NB笔记)

        # read class_indict
        json_path = './class_indices.json'   # 读取保存的json文件(类别名称以及对应的索引)
        assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

        json_file = open(json_path, "r")  # 解码成所需要的字典
        class_indict = json.load(json_file)

        # create model
        model = AlexNet(num_classes=5).to(device)   # 初始化网络

        # load model weights
        weights_path = "./AlexNet.pth"
        assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
        model.load_state_dict(torch.load(weights_path))  # 载入网络模型

        model.eval()   # 进入eval模式(即关闭掉droout方法)
        with torch.no_grad():   # 不跟踪变量的损失梯度
            # predict class
            output = torch.squeeze(model(img.to(device))).cpu()   # 将图片通过model正向传播,得到输出,将输入进行压缩,将batch维度压缩掉,得到最终输出(out)
            predict = torch.softmax(output, dim=0)  # 经过softmax处理后,就变成概率分布的形式了
            predict_cla = torch.argmax(predict).numpy()  # 通过argmax方法,得到概率最大的处所对应的索引


        print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())   # 打印类别名称以及他所对应的预测概率

        plt.title(print_res)
        for i in range(len(predict)):
            with open('test.txt', 'a') as file0:  # 将以下print内容保存到test.txt文件中
                print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                      predict[i].numpy()), file=file0)
        with open('test.txt', 'a') as file0:
            print("--------------------------我是可爱的分隔线--------------------------", file=file0)
        # plt.show()


if __name__ == '__main__':
    main()

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值