MATLAB上丝滑地调用自己训练的pytorch神经网络模型
一、环境配置准备
在MATLAB上加载pytorch环境
在MATLAB中运行pytorch的虚拟环境,需要我们自己在MATLAB的命令行输入:
pyversion 'C:\\Users\\31105\\.conda\\...\\python.exe'
中间的路径是虚拟环境的绝对路径。
如果忘记了虚拟环境在哪,要找这个路径,可以有两种方法
1, 在VScode里鼠标停在内核按键就会直接显示出来
2,在cmd中:
conda activate shi #(shi是虚拟环境名字)
Python
import torch
print(torch.path)
然后顺着torch的路径往回找就能找到python.exe的路径了
使用pth文件将训练好的模型的参数保存
这个方法应该很好找,本文就不展开讲了
二、调用文件的准备
本文也是参考了这篇文章 MATLAB调用Pytorch神经网络模型进行预测 不过为了适配自己的项目,就一些细节做了调整。在模型方面,本文将展示基于MLP模型的调用文件的具体代码。
module.py文件的准备
创建一个module.py文件。这个文件所描述的是模型的框架,具体如下:
from torch import nn
import torch.nn.functional as f
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.linear_relu_stack = nn.Sequential(
nn.Linear(1401,800),
nn.ReLU(),
nn.Linear(800, 600),
nn.ReLU(),
nn.Linear(600, 400),
)
#self.initialize_weights()
#self.initialize_parameters()
def forward(self, x):
logits = self.linear_relu_stack(x)
return logits
model_pre.py文件的准备
创建一个model_pre.py文件,这个部分就是将训练好的模型的参数代入到模型框架中。这个加载函数的输入值:params_path就是保存好的pth的模型参数文件的路径。
import torch
import numpy as np
from module import NeuralNetwork
def model_load(params_path):
model = NeuralNetwork()#定义模型
model.load_state_dict(torch.load(params_path, map_location=torch.device('cpu')))#加载模型参数
model.eval()#设置为测试模式
return model
pred.py文件的准备
创建一个文件专门用来预测,本文特意单独开的这个预测文件,这样子可以使每次的预测的时间开销小些。这个预测函数的输入值:model就是前文加载好的模型,data就是所需要预测的数据。
import torch
import numpy as np
def prediction(model, data):
data = torch.tensor(data, dtype=torch.float32)#转为tensor
with torch.no_grad():#不计算梯度,因为是预测阶段
pred = model(data)#预测
pred = pred.numpy()#转为numpy数组
return pred
三、MATLAB上的调用
首先将上述所有文件保存在MATLAB的工作文件夹中,然后在MATLAB上加载这三个文件,因为本文的内置函数需要多次调用,故用全局变量将加载和预测工作分开。
下面在MATLAB上编写如下代码。
加载代码:
global predi;
global model;
MLP = py.importlib.import_module('module');%读取
py.importlib.reload(MLP);%加载
mode = py.importlib.import_module('model_pre');%读取
py.importlib.reload(mode);%加载
model = mode.model_load(pyargs('params_path', 'D:\\MATLAB2023\\bin\\...\\MLP_4000.pth'));
predi = py.importlib.import_module('pred');%读取
py.importlib.reload(predi);%加载
预测代码:
global predi;
global model;
f_pred = predi.prediction(pyargs('model',model,'data', data1));%预测
f_pred = double(f_pred);
在作者的项目中,该方法调用的速度较快,可以满足正常程序对于速度的需求。
至此,这个方法就已经完整说明了。
祝你顺利地完成上述操作,在MATLAB上丝滑地调用自己训练好的pytorch模型!