pytorch载入训练好的模型并进行可视化模型预测绘图

main函数载入模型,加载图片,输出结果:

if __name__ == '__main__':
  image =  Image.open(r"C:\Users\pic\test\he_5.jpg")
    image =transform(image).unsqueeze(0)
    modelme = torch.load('modefresnet.pkl')
    modelme.eval() #表示将模型转变为evaluation(测试)模式,这样就可以排除BN和Dropout对测试的干扰。
     visualize_model(modelme)
    outputs = modelme(image)
    _, predict = torch.max(outputs.data, 1)
        for j in range(image.size()[0]):
     print('predicted: {}'.format(class_names[predict[j]]))

对图片的统一处理transform:

transform=transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406],
                                 std=[0.229,0.224,0.225])
                            ])

对于预测结果进行可视化的函数:

def visualize_model(model, num_images=6):
 was_training = model.training
 model.eval()
 images_so_far = 0
 fig = plt.figure()
 with torch.no_grad():

     #for i, (inputs, labels) in enumerate(dataloaders['val']):
     for i, (inputs, labels) in enumerate(testloder):
       outputs = model(inputs)
       _, preds = torch.max(outputs, 1)
       
       for j in range(inputs.size()[0]):

           images_so_far += 1

           ax = plt.subplot(num_images // 2, 2, images_so_far)

           ax.axis('off')

           ax.set_title('predicted: {}'.format(class_names[preds[j]]))

           imshow(inputs.cpu().data[j])


           if images_so_far == num_images:

            model.train(mode=was_training)
            plt.show()

            return

     model.train(mode=was_training)

载入一新的图片数据集:

data_dir =os.getcwd() + '\\data\\'
dataloadertest =datasets.ImageFolder(os.path.join(data_dir, "tt"),transform=transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406],
                                 std=[0.229,0.224,0.225])
                            ]) )
testloder = torch.utils.data.DataLoader(dataloadertest,batch_size = 4,shuffle = True)

目录结构:
在这里插入图片描述
其中要注意传入的图片的预处理:
image = Image.open(r"C:\Users\pic\test\he_5.jpg")
image =transform(image).unsqueeze(0)
需为PIL格式,且需先进行转化才能传入模型。

结果:
在这里插入图片描述

在这里插入图片描述
经测试之后不论是传入单张图片还是一个新数据集结果均符合预期。

已标记关键词 清除标记
相关推荐
<p style="text-align:left;"> <br /> </p> <p style="text-align:left;"> <span style="color:#E53333;"><strong>【课程介绍】</strong></span>  </p> <p style="text-align:left;">      Pytorch项目实战 垃圾分类 课程从实战的角度出发,基于真实数据集与实际业务需求,结合当下最新话题-垃圾分类问题为实际业务出发点,介绍最前沿的深度学习解决方案。 </p> <p style="text-align:left;">     从0到1讲解如何场景业务分析、进行数据处理,模型训练与调优,最后进行测试与结果展示分析。全程实战操作,以最接地气的方式详解每一步流程与解决方案。 </p> <p style="text-align:left;">     课程结合当下深度学习热门领域,尤其是基于facebook 开源分类神器ResNext101网络架构,对网络架构进行调整,以计算机视觉为核心讲解各大网络的应用于实战方法,适合快速入门与进阶提升。 </p> <p style="text-align:left;"> <strong><span style="color:#E53333;">【课程要求】</span></strong> </p> <p style="text-align:left;"> (1)开发环境:python版本:Python3.7+;<span style="color:#E53333;"> torch 版本:1.2.0+; torchvision版本:0.4.0+</span> </p> <p style="text-align:left;"> (2)开发工具:Pycharm; </p> <p style="text-align:left;"> (3)学员基础:需要一定的Python基础,及深度学习基础; </p> <p style="text-align:left;"> (4)学员收货:掌握最新科技图像分类关键技术; </p> <p style="text-align:left;"> (5)学员资料:内含完整程序源码和数据集; </p> <p style="text-align:left;"> (6)课程亮点:专题技术,完整案例,全程实战操作,徒手撸代码 </p> <p style="text-align:left;"> <br /> </p> <p style="text-align:left;"> <span style="color:#E53333;"><strong>【课程特色】</strong></span> </p> 阵容强大 <p style="text-align:left;"> <br /> </p> <p style="text-align:left;"> 讲师一直从事与一线项目开发,高级算法专家,一直从事于图像、NLP、个性化推荐系统热门技术领域。 </p> <p style="text-align:left;"> 仅跟前沿 </p> <p style="text-align:left;"> 基于当前热门讨论话题:垃圾分类,课程采用学术届和工业届最新前沿技术知识要点。 </p> <p style="text-align:left;"> 实战为先 </p> <p style="text-align:left;"> 根据实际深度学习工业场景-垃圾分类,从产品需求、产品设计和方案设计、产品技术功能实现、模型上线部署。精心设计工业实战项目 </p> <p style="text-align:left;"> 保障效果 </p> <p style="text-align:left;"> 项目实战方向包含了学术届和工业届最前沿技术要点 </p> <p style="text-align:left;"> 项目包装简历优化 </p> <p style="text-align:left;"> 课程内垃圾分类图像实战项目完成后可以直接优化到简历中 </p> <p style="text-align:left;"> <strong><span style="color:#E53333;">【课程思维导图】</span></strong> </p> <p style="text-align:left;"> <img src="https://img-bss.csdn.net/201912081323318969.png" alt="" /> </p> <p style="text-align:left;"> <br /> </p> <p style="text-align:left;"> <br /> </p> <p style="text-align:left;"> <strong><span style="color:#E53333;">【课程实战案例】</span></strong> </p> <p style="text-align:left;"> <br /> </p> <p style="text-align:left;"> <img src="https://img-bss.csdn.net/201912081326184463.png" alt="" /> </p> <p style="text-align:left;"> <br /> </p> <p style="text-align:left;"> <br /> </p>
©️2020 CSDN 皮肤主题: 鲸 设计师:meimeiellie 返回首页