【记录】pth-->onnx-->om华为Atlas200DK深度学习Pytorch模型部署初体验

〇、官方参考文档

正文开始之前,附几个官方链接,以官方教程为主,下文为辅进行模型部署。

昇腾社区-官网丨昇腾万里 让智能无所不及   查阅文档内容

资源-Atlas 200I DK A2-昇腾社区      翻页-->开发课程-->开发者课程-->图片分类应用开发入门教程

昇腾论坛

一、引言

由于本科毕设要求,需要将pytorch图像识别(10分类)模型进行嵌入式端部署。本文假设已经获得训练好的.pth文件(权重+模型),介绍从pth文件到onnx文件再到om模型的转换以及python/C++的推理过程。

 二、软硬件准备

  1. Atlas 200I DK A2 开发者套件
  2. windows 10或ubuntu系统(用以模型转换)
  3. .pth模型
  4. Python3.11+PyTorch环境

三、pth-->onnx模型转换

1、python代码转换

import torchvision
import onnx
import torch
from vgg16 import Vgg16Classifier
from resnet18 import ResNet18Classifier 
        #如果保存的pth文件是权重文件不包含模型框架,
        #需要在代码中加入模型类的定义或者通过from的方式import
num_classes = 10 #模型分类输出数
DIR_STATE_DIC_PTH = './model/resnet18_offical.pth' #pth模型地址
ONNX_MODEL_PATH = './onnx/' #onnx模型保存地址
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 将state_dic_pth权重转换成完整的pth文件
def get_full_pth(type, DIR_STATE_DIC_PTH):
    if type=='vgg16':
        model = Vgg16Classifier(1, num_classes)

    elif type=='resnet18':
        model = ResNet18Classifier(1, num_classes) #(通道数,输出类别数)
        # 此处根据自己模型定义设置,相当于实例化一个类对象
    model.load_state_dict(torch.load(DIR_STATE_DIC_PTH))  # weight=torch.load(DIR_STATE_DIC_PTH)
    #将权重加载到实例化的类对象内,获得完整的pth文件
    print(f'{type}.pth完整模型已获得!')
    return model
def pth2onnx(type):
    model = get_full_pth(type, DIR_STATE_DIC_PTH)
    model.to(device)
    model.eval()
    dummy_img = torch.Tensor(torch.randn((1, 1, 224, 224)).cuda()) # (batchsize, channels, width, height)
    torch.onnx.export(
        model = model, #pth模型
        args=dummy_img,
        f=ONNX_MODEL_PATH+type+'.onnx', #onnx保存模型地址
        export_params=True,
        verbose=False,
        input_names=['input'],
        output_names=['output'],
        opset_version=11 #pth转onnx好像都是设置=11
    )
    print(f'已生成{type}.onnx文件!')
# 按间距中的绿色按钮以运行脚本。
if __name__ == '__main__':
    pth2onnx('resnet18')
    

值得注意的是,pth转onnx要保证pth文件既包含训练获得的权重(字典格式),也要包括模型结构本身。

1、方式一

一般训练过程中会在训练过程中间断地保存模型,一般选择保存字典格式

torch.save(ResNet18Model.state_dict(), './Model_train3/ResNet18Model_{}.pth'.format(i + 1))

那么对应的加载方式为:

ResNet18Model = ResNet18Classifier(in_channels=1, num_classes=num_classes)
ResNet18Model.load_state_dict(torch.load('./Model_train2/ResNet18Model_395.pth')) # 字典的方式需要先实例化再load


 2、方式二

也可以选择保存全部,占用内存会比字典格式大一丢丢:

torch.save(ResNet18Model, './Model/ResNet18Model_{}.pth'.format(i+1))

那么对应的加载方式为:

ResNet18Model = torch.load('./Model_train2/ResNet18Model_46.pth') # 导入整个模型的方式不需要实例化模型和导入类

如果使用第二种保存方式,进行pth-->onnx转换时,可以忽略前面代码中提到的get_full_pth()函数,只用pth2onnx()函数即可。因为最重要的代码就是torch.onnx.export()这句代码。

需要注意的是pth-->onnx模型转换时要保证device和pth模型训练时保持一致。

2、onnx模型测试

1、Netron模型结构查看

将pth文件和onnx首先丢进Netron看看模型结构是否一致。

注意:丢进Netron的pth文件要保证既包含权重也包含模型结构。如果保存pth选择方式一,则需要用方式一加载的方式加载,再通过方式二保存的方式保存一下。

左图表示pth未包含模型结构的情况,右图表示完成的pth和onnx的情况。

2、代码推理pth和onnx输出是否一致

代码借鉴了深度学习届扛把子,非常感谢。

import numpy as np
import sys

import cv2
import onnxruntime
import torch
from torchvision import transforms
sys.path.append('..')
from resnet18 import ResidualBlock
from resnet18 import ResNet18Classifier

pth_model_path = '../model/resnet.pth'
img_pth = './images/T62/T62_1.JPG'
onnx_model_path = 'resnet18.onnx'
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(224, antialias=True),  # 抗锯齿
    transforms.Grayscale(1),  # 注意:使用ImageFolder读取的图片会变成RGB三通道图片
    transforms.Normalize(mean=[0], std=[1])
])

img = cv2.imread(img_pth)

img = trans(img)
img = np.expand_dims(img, axis=0)  # 扩展第一维度,适应模型输入
img = torch.tensor(img)
print(img.shape)

resnet18 = ResNet18Classifier(1, 10)
net = resnet18
net.load_state_dict(torch.load(pth_model_path))
# print(net)
net.eval()
output = net(img)
# print(output)
print("pth weights", output.detach().cpu().numpy()) #10分类输出的10个值
print("pth prediction", output.argmax(dim=1)[0].item())


# onnx测试
resnet_session = onnxruntime.InferenceSession(onnx_model_path)
# compute ONNX Runtime output prediction
img = np.array(img)
inputs = {resnet_session.get_inputs()[0].name: img}
outs = resnet_session.run(None, inputs)[0]

print("onnx weights", outs)
print("onnx prediction", outs.argmax(axis=1)[0])

如果两部分输出的值大小几乎一致(因为小数点后几位可能有模型转换导致的精度损失,应该可以忽略不计的)则说明转换成功,否则失败。

这里是我花费时间调试较多的一个地方,问题出在输入数据的格式上面。即图片转化成tensor数据格式的过程,因为要保证此处输入的数据格式和训练过程中的数据格式一致,因此可以参考自己训练pth模型时的代码,我在学习时,一般教程中都是用ImageFloader进行图片读取,而ImageFloader是这样使用的:

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True),
    transforms.Grayscale(1), # 注意:使用ImageFolder读取的图片会变成RGB三通道图片
    transforms.Normalize(mean=[0], std=[1])
])
train_data = torchvision.datasets.ImageFolder(root='./data/train', transform=trans)

而在后续在开发板上进行om模型推理时,数据的输入格式和测试pth测试onnx模型的输入格式是一致的,因此如果此处仍然选择使用ImageFloader进行图片读取,则在开发板上进行om模型推理时,为了达成数据预处理的要求,可能就得安装torch来进行,这样比较麻烦,而在嵌入式系统安装cv2教程相对比较容易,因此此处建议大家使用cv2进行图片读取,可以使用torch相关的包进行tensor数据格式转换,因为Atlas 200I DK A2 开发者套件内由ToTensor()函数可以替换。

此处用cv2进行数据预处理,会在开发板惊醒om模型推理前数据预处理时少走一些弯路。

四、onnx-->om模型转换 

在此之前,建议首先观看文章开头提到的第二个链接对应的教学视频,首先了解MobaXterm软件进行文件传输、远程登录的方式。

我使用的是开发板通过typec与windows电脑进行连接,根据文章开头提到的第一个链接对应文档进行开发板和电脑网络配置,实现SSH远程登陆,通过命令行安装tigervnc实现远程桌面,另外开发板的上网功能也在文章开头第一个链接中找到答案。

我参考的onnx-->om的文档在   文档-应用开发指南-模型转换  

1、开发者套件进行onnx-->om模型转换 

详见在开发者套件上进行模型转换

将从windows pytorch中转换得到的onnx模型传输到开发板的指定位置。

核心命令为:

atc --model=/root/ImageClassifier/resnet18.onnx --framework=5 --output=/root/ImageClassifier/resnet18 --soc_version=Ascend310B4 
#说明:
#atc --model=ONNX_PATH --framework={onnx转om填5,其他情况见文档} --output=OM_SAVE_PATH --soc_version={通过npi-smi info命令查看版本号,如果是310B则输入Ascend310B}

此处,我测试过程中,转换我自己写的vgg16网络,大概500M,可以在十几分钟内转换成功;转换自己写的resnet18网络,大概40M,转换几个小时后报一串错误;转换文档中提到的例程resnet50可以在10分钟之内成功;转换torchvision.models.resnet18()也报错。通过CSDN以及文章开头提到的第三个链接的昇腾论坛中也未解决。

某次报错截图如下:

错误中提到的关键词为Exception in thread Thread-1、 Failed to compile Op...Conv2D。考虑为算子不支持导致的,或许和开发板中的atc工具有关,暂时未解决。

2、在Ubuntu系统上转换模型(建议)

在ubuntu系统上进行模型转换,相比于开发板有更多的计算资源,成功率较高。这也是昇腾论坛中一些技术人员的建议。

由于我不方便在电脑上装双系统,也没试试VMware虚拟机怎样,看到很多wsl的教程,就试了试,结果成功了,但仍有部分warning未解决,但经过后边的om推理,仍然是成功的。

参考文档:使用WSL安装Linux Ubuntu 22.04

比较麻烦的是需要手动安装CANN。

具体的命令见文档教程,一步一步跟随安装即可。虽然但是,这部分安装花费了我一天的时间。

需要注意的是:

1、WSL的文件系统进入方法是:在windows任意文件地址栏输入:\\wsl$    即可进入。

里面/mnt/文件夹下的c、d、e等文件夹对应windows系统的C、D、E等盘符。WSL需要某个文件时,可以通过在WSL中cd对应位置获取,也可以通过ctrl c v复制粘贴到WSL中某个方便一点的文件夹内。

2、文档中给出的下载地址对应的下载包可能和文档给出的命令对应不一致,比如我点击【下载链接】跳转下载的是Ascend-cann-toolkit_7.0.RC1_linux-x86_64.run,于是chmod +x Ascend-cann-toolkit_6.2.RC2_linux-x86_64.run、./Ascend-cann-toolkit_6.2.RC2_linux-x86_64.run --install两个命令文件名需要对应修改。

3、安装依赖一定要保证依赖下载完整。建议使用非root用户安装,即创建一个个人用户,创建成功应该会在/home/文件夹下有{user}用户名文件夹。不要root和非root混乱安装。可以通过pip3 list来检查文档中依赖要求的包是否都安装上了。

4、安装依赖过程中我遇到pip3 install scipy --user失败的问题,可以去scipy官网下载whl再安装的方式进行。具体下载哪个版本,可以通过pip3 install scipy --user下载失败时打印的信息看到,例如我的就打印了balabala.......scipy-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl........balabala,所以去官网对应找whl就行了。

atc --model=/home/kk/onnx/vgg16.onnx --framework=5 --output=/home/kk/om/vgg16 --soc_version=Ascend310B4 
#说明:
#atc --model=ONNX_PATH --framework={onnx转om填5,其他情况见文档} --output=OM_SAVE_PATH --soc_version={通过npi-smi info命令查看版本号,如果是310B则输入Ascend310B}

这里报了特别多警告,经过查资料和咨询一些前辈,考虑为可能为Ubntu22.04LST的问题,阉割了部分功能,听说在完整版上正常不报警告。

即便报以上警告,经测试转换的om模型结果输出是正常的。

 五、通过MindX SDK(python)进行om模型推理

通过AscendCL(python/C++)的方式还没学会。

参考文章开头第二个链接对应的视频学习视频

视频中提到的文档的下载链接在下载代码链接页面可以找到。将文件放到开发板的某个目录下,cd到该目录后运行:

python3 main.py

如果类似报错:

(base) root@davinci-mini:~/ImageClassifier/resnet18_inference# python main.py
Traceback (most recent call last):
  File "/root/ImageClassifier/resnet18_inference/main.py", line 4, in <module>
    from mindx.sdk import Tensor  # mxVision 中的 Tensor 数据结构
  File "/root/.local/lib/python3.9/site-packages/mindx/__init__.py", line 12, in <module>
    from . import sdk
  File "/root/.local/lib/python3.9/site-packages/mindx/sdk/__init__.py", line 12, in <module>
    from .base import *
ImportError: libglog.so.1: cannot open shared object file: No such file or directory

请先配置环境变量:(运行)

. /usr/local/Ascend/mxVision-5.0.RC3/set_env.sh

然后再执行:

python3 main.py

看看是否和视频中提到的效果相同。如果正常(即便有个warning,因为视频中也看到有warning,没事)运行,把/model/文件夹的om模型替换(其他两个文件用不到,可以删除),/data/文件夹的测试图片替换,/utils/文件夹中.cfg文件和.name对应修改为自己的内容。

最重要的是修改main.py文件。修改之前的官方main.py就不附上了,下面是修改后的代码:

import numpy as np  # 用于对多维数组进行计算
import cv2  # 图片处理三方库,用于对图片进行前后处理

from mindx.sdk import Tensor  # mxVision 中的 Tensor 数据结构
from mindx.sdk import base  # mxVision 推理接口
from mindx.sdk.base import post  # post.Resnet50PostProcess 为 resnet50 后处理接口


'''初始化资源和变量'''
base.mx_init()  # 初始化 mxVision 资源
pic_path = 'data/T62_1.JPG'  # 单张图片
model_path = "model/resnet18.om"  # 模型路径
device_id = 0  # 指定运算的Device 默认:0
config_path='utils/resnet18.cfg'  # 后处理配置文件
label_path='utils/resnet18_clsidx_to_labels.names'  # 类别标签文件
img_size = 224

'''前处理'''
# 此处的预处理一定要和pth训练时预处理方式一致
img_original = cv2.imread(pic_path, cv2.IMREAD_GRAYSCALE)  # 读取单通道灰度图片
img_gray = cv2.resize(img_original, (img_size, img_size))  # 缩放到目标大小
# 归一化处理:将像素值缩放到[0, 1]范围
img_normalized = img_gray / 255.0
# 均值为0方差为1的归一化处理
img_mean = 0.0  # 目标均值
img_std = 1.0  # 目标标准差
img_normalized = (img_normalized - img_mean) / img_std
img = img_normalized
img = np.expand_dims(img, axis=0)  # 扩展第一维度,适应模型输入
img = np.expand_dims(img, axis=0)  # 扩展第一维度,适应模型输入
# img = img.transpose([0, 3, 1, 2])  # 将 (batch,height,width,channels) 转为 (batch,channels,height,width)
 
img = np.ascontiguousarray(img, dtype=np.float32)# 将内存连续排列
print(img.shape)
# 将归一化后的图像转换为Tensor
img_tensor = Tensor(img)  # 将numpy数组转换为Tensor
#img = img.astype('float32')  # 转换数据类型为float32


'''模型推理'''
#model = base.model(modelPath=model_path, deviceId=device_id)  # 初始化 base.model 类
#output = model.infer([img])[0]  # 执行推理。输入数据类型:List[base.Tensor], 返回模型推理输出的 List[base.Tensor]
'''模型推理'''
model = base.model(modelPath=model_path, deviceId=device_id)  # 初始化 base.model 类
if model:
    # 确保img_tensor是正确的输入格式
    input_tensors = [img_tensor]
    try:
        outputs = model.infer(input_tensors)  # 执行推理
        if outputs:
            output = outputs[0]  # 获取第一个输出张量
            #print(f'output:{output}')
            # ...后续的后处理和错误检查...
        else:
            print("推理返回了一个空outputs结果")
    except Exception as e:
        print(f"推理失败,错误为: {e}")
else:
    print("模型导入失败!请检查model_path和device_id.")

'''后处理'''
postprocessor = post.Resnet50PostProcess(config_path=config_path, label_path=label_path)  # 获取后处理对象
pred = postprocessor.process([output])[0][0]  # 利用sdk接口进行后处理,pred:<ClassInfo classId=... confidence=... className=...>
print(pred)
confidence = pred.confidence  # 获取类别置信度
className = pred.className  # 获取类别名称
print('{}: {}'.format(className, confidence))  # 打印出结果  

'''保存推理图片'''
# 在图像下方新增一块空白区域
height, width = img_original.shape
margin_bottom = 30
new_height = height + margin_bottom
new_img = cv2.copyMakeBorder(img_original, 0, margin_bottom, 0, 0, cv2.BORDER_CONSTANT, value=(0, 0, 0))

img_res = cv2.putText(new_img, f'{className}: {confidence:.2f}', (0+10, new_height-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,0,0), 1)  # 将预测的类别与置信度添加到图片
print(img_res.shape)
cv2.imwrite('resnet18.jpg', img_res)
print('save infer result success')

代码说明:

1、预处理方式一定要和训练时图片预处理结果一致,包括通道数channels(有些图片读取方式会把单通道图读成3通道图,需通过grayscale相关的方式修改)、图片size(通过resize修改)、tensor、img.shape等。

ToTensor()一定是在输入模型前的最后一步

img.shape可能为[1, 1, w, h],其中第一个‘1’为batchsize,这个应该不是训练设置的batchsize而是pth转换onnx时,dummy_img = torch.Tensor(torch.randn((1, 1, 224, 224)).cuda()) # (batchsize, channels, width, height)的设置有关。第二个‘1’为通道数。

2、由于cv2读取的图片返回值为[w, h],与模型输入要求[batchsize, channels, width, height]不符,可以通过img = np.expand_dims(img, axis=0)  # 扩展第一维度,适应模型输入,来扩展数据维度。

3、img = np.ascontiguousarray(img, dtype=np.float32)# 将内存连续排列,并将数据精度设置为float32

这一句很关键,对于不止1个通道的图片来说,一定要设置连续内存。

如果报错E20240328 23:52:27.393335 326321 MxOmModelDesc.cpp:823] Please check inputTensors datasize: 401408, or inputTensor_: 200704. (Code = 1003, Message = "Invalid Pointer")类似,其中出现了两个关键的数据401408和200704,其存在2倍关系,就可以考虑为数据精度的问题,加上img = img.astype('float32') 或者img = np.ascontiguousarray(img, dtype=np.float32)应该可以解决。

4、模型推理部分,为了更方便看到运行到哪一步报错,所以加了try...except...调试正常后其实就是注释掉的两句代码而已。

5、后处理部分postprocessor = post.Resnet50PostProcess(config_path=config_path, label_path=label_path)为什么使用Resnet50PostProcess()呢,好像是因为MindX SDK为python就提供了这么一个API接口供分类使用。

顺利的话,现在应该可以正常输出了。在将分类结果写在图片上时发现,黑白图片上写字可能不清晰,于是在main.py最后部分的代码中加了句,在图片下方贴上一块空白区域,专门用来写分类结果,就清晰了。目前main.py只能单张图片测试,后续学习如何进行批量图片测试。

哈哈哈完结!第一篇博客顺利完成!如果后边实现了C++的推理方式再追更!

  • 29
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 8
    评论
这里提供一个简单的 PyTorch 实现垃圾分类的示例代码,仅供参考: 1. 数据准备 首先需要准备垃圾分类数据集,可以从网上下载或者自己制作。这里使用的是 Kaggle 上的垃圾分类数据集,下载地址为:https://www.kaggle.com/techsash/waste-classification-data 下载完成后,可以将数据集解压到本地路径 `./data` 下。 2. 模型设计 我们使用 PyTorch 实现一个简单的 CNN 模型,用于对垃圾图片进行分类。模型代码如下: ```python import torch.nn as nn class Net(nn.Module): def __init__(self, num_classes=6): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm2d(32) self.relu1 = nn.ReLU() self.pool1 = nn.MaxPool2d(kernel_size=2) self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm2d(64) self.relu2 = nn.ReLU() self.pool2 = nn.MaxPool2d(kernel_size=2) self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.bn3 = nn.BatchNorm2d(128) self.relu3 = nn.ReLU() self.pool3 = nn.MaxPool2d(kernel_size=2) self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1) self.bn4 = nn.BatchNorm2d(256) self.relu4 = nn.ReLU() self.pool4 = nn.MaxPool2d(kernel_size=2) self.fc1 = nn.Linear(256 * 6 * 6, 1024) self.relu5 = nn.ReLU() self.dropout = nn.Dropout(p=0.5) self.fc2 = nn.Linear(1024, num_classes) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu1(x) x = self.pool1(x) x = self.conv2(x) x = self.bn2(x) x = self.relu2(x) x = self.pool2(x) x = self.conv3(x) x = self.bn3(x) x = self.relu3(x) x = self.pool3(x) x = self.conv4(x) x = self.bn4(x) x = self.relu4(x) x = self.pool4(x) x = x.view(-1, 256 * 6 * 6) x = self.fc1(x) x = self.relu5(x) x = self.dropout(x) x = self.fc2(x) return x ``` 3. 训练模型 准备好数据集和模型之后,我们可以开始训练模型了。这里使用 PyTorch 提供的 DataLoader 工具来加载数据集,并使用交叉熵作为损失函数,Adam 作为优化器。 ```python import os import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision.datasets import ImageFolder from torchvision.transforms import transforms # 数据预处理 data_transforms = { 'train': transforms.Compose([ transforms.Resize(256), transforms.RandomCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]), 'val': transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]), } # 加载数据集 data_dir = './data' image_datasets = {x: ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']} dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4) for x in ['train', 'val']} dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} class_names = image_datasets['train'].classes # 定义模型 model = Net(num_classes=len(class_names)) model = model.cuda() # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=1e-4) # 训练模型 num_epochs = 10 best_acc = 0.0 for epoch in range(num_epochs): print('Epoch {}/{}'.format(epoch + 1, num_epochs)) print('-' * 10) # 训练阶段 model.train() running_loss = 0.0 running_corrects = 0 for inputs, labels in dataloaders['train']: inputs = inputs.cuda() labels = labels.cuda() optimizer.zero_grad() outputs = model(inputs) _, preds = torch.max(outputs, 1) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() * inputs.size(0) running_corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / dataset_sizes['train'] epoch_acc = running_corrects.double() / dataset_sizes['train'] print('Train Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc)) # 验证阶段 model.eval() running_loss = 0.0 running_corrects = 0 for inputs, labels in dataloaders['val']: inputs = inputs.cuda() labels = labels.cuda() with torch.no_grad(): outputs = model(inputs) _, preds = torch.max(outputs, 1) loss = criterion(outputs, labels) running_loss += loss.item() * inputs.size(0) running_corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / dataset_sizes['val'] epoch_acc = running_corrects.double() / dataset_sizes['val'] print('Val Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc)) # 保存最好的模型参数 if epoch_acc > best_acc: best_acc = epoch_acc torch.save(model.state_dict(), 'best_model.pth') ``` 4. 测试模型 训练完成后,我们可以使用测试集对模型进行测试,代码如下: ```python model = Net(num_classes=len(class_names)) model.load_state_dict(torch.load('best_model.pth')) model = model.cuda() model.eval() running_corrects = 0 for inputs, labels in dataloaders['test']: inputs = inputs.cuda() labels = labels.cuda() with torch.no_grad(): outputs = model(inputs) _, preds = torch.max(outputs, 1) running_corrects += torch.sum(preds == labels.data) test_acc = running_corrects.double() / dataset_sizes['test'] print('Test Acc: {:.4f}'.format(test_acc)) ``` 这样就完成了 PyTorch 实现垃圾分类的示例代码。需要注意的是,这里只是一个简单的示例,实际应用中还需要根据具体情况进行调整和优化。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值