大数据最新PyTorch的ONNX结合MNIST手写数字数据集的应用((1),2024年最新原生大数据开发开发的路该怎么走

img
img

网上学习资料一大堆,但如果学到的知识不成体系,遇到问题时只是浅尝辄止,不再深入研究,那么很难做到真正的技术提升。

需要这份系统化资料的朋友,可以戳这里获取

一个人可以走的很快,但一群人才能走的更远!不论你是正从事IT行业的老鸟或是对IT行业感兴趣的新人,都欢迎加入我们的的圈子(技术交流、学习资源、职场吐槽、大厂内推、面试辅导),让我们一起学习成长!

我们来看一个在CPU的环境下的加载方法,mnist.pth文件下载地址:mnist.pth

import torch
model=torch.load("mnist.pth")
print(type(model['net']),len(model['net']))
for k,v in model['net'].items():
    print(k,v.size())

'''
<class 'collections.OrderedDict'> 10
conv1.weight torch.Size([6, 1, 3, 3])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 3, 3])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
'''

这里可以不指定map_location参数,默认是cpu设备,可以看到这个pth文件结构是两个卷积层加三个全连接层。

3、pth转onnx

我们根据上面的mnist.pth结构,自己来构造一个模型:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=3,stride=1,padding=0)
        self.conv2 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=3,stride=1,padding=0)
        self.fc1   = nn.Linear(400, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)
 
    def forward(self, x):
        out = self.conv1(x) # torch.Size([1, 6, 26, 26])
        out = F.max_pool2d(F.relu(out), 2) # [1, 6, 13, 13]
        out = self.conv2(out) # [1, 16, 11, 11]
        out = F.max_pool2d(F.relu(out), 2)  # [1, 16, 5, 5]
        out = out.view(out.size(0), -1) # [1, 400]
        out = self.fc1(out) # [1, 120]
        out = self.fc2(F.relu(out)) # [1, 84]
        out = self.fc3(F.relu(out))  # [1, 10]
        return out

net = LeNet()
net = net.to('cpu')
checkpoint = torch.load('mnist.pth')
net.load_state_dict(checkpoint['net'])
batch_size = 1
input_shape = (1,28,28)
x = torch.randn(batch_size,*input_shape)
net.eval()
torch.onnx.export(net,x,"mnist.onnx")

构造一样的结构,加载mnist.pth,然后就可以通过export转换成onnx格式的文件了。我们上传到https://netron.app/ 站点,可视化整个模型图,然后点击每个节点,将在右边出现它们的属性值:

4、onnx运行时

onnxruntime主要是拿来推理,当然在ir7的版本也增加了训练等功能,我们来了解下这个东西

4.1、安装模块

如果缺少onnxruntime模块,就会报错:

ModuleNotFoundError: No module named ‘onnxruntime’

这里在JupyterLab中,所以在前面加一个叹号安装

!pip install onnxruntime -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com

import torch
import onnxruntime as ort
import numpy as np

session = ort.InferenceSession("mnist.onnx")
x = np.random.rand(1, 1, 28, 28).astype(np.float32)
outputs = session.run(None, {"input": x})
print(outputs[0])
4.2、名称一致

这里容易出错:InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid Feed Input Name:input

也就是说这个sess.run([output_name], {input_name: x})中的输入名称错误,所以名称要一样,这里的输入名称是input.1,修改成**outputs = session.run(None, {“input.1”: x})**就可以了
怎么查看名称,可以通过上面站点可视化直接看到名称,也可以使用下面代码获取

input_name = session.get_inputs()
print(input_name[0].name)#input.1

同样的,如果输出名称也想指定,可以使用下面代码获取

out_name = session.get_outputs()[0].name
4.3、三通道转一通道

彩色三通道的图片转成灰色的单通道图片:

import cv2
import numpy as np
img = cv2.imread('1.png', cv2.IMREAD_GRAYSCALE)
cv2.imwrite('1.jpg',img)
print(img.shape)#(28, 28)

5、转成json格式

有时候的需求需要可读文件,一般json是很常见的,也可以进行转换:

img
img

网上学习资料一大堆,但如果学到的知识不成体系,遇到问题时只是浅尝辄止,不再深入研究,那么很难做到真正的技术提升。

需要这份系统化资料的朋友,可以戳这里获取

一个人可以走的很快,但一群人才能走的更远!不论你是正从事IT行业的老鸟或是对IT行业感兴趣的新人,都欢迎加入我们的的圈子(技术交流、学习资源、职场吐槽、大厂内推、面试辅导),让我们一起学习成长!

这里获取](https://bbs.csdn.net/forums/4f45ff00ff254613a03fab5e56a57acb)**

一个人可以走的很快,但一群人才能走的更远!不论你是正从事IT行业的老鸟或是对IT行业感兴趣的新人,都欢迎加入我们的的圈子(技术交流、学习资源、职场吐槽、大厂内推、面试辅导),让我们一起学习成长!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值