matlab调用pytorch模型

参考链接MATLAB调用Pytorch神经网络模型进行预测 - 知乎 (zhihu.com)

方法

m文件,读取调用模型所需的py文件module.py和model_pred.py。原链接说是“如果Python的代码修改过是需要重新加载的,为了方便调试,避免频繁报错,直接在m中对两个Python文件都使用reload重加载”。

function result = ModelForPy(params_path, params, data_input)
    module= py.importlib.import_module('module');
    py.importlib.reload(module);
    model_pred= py.importlib.import_module('model_pred');
    py.importlib.reload(model_pred);
    result = model_pred.prediction(params, pyargs('params_path', params_path, 'data_input', data_input));
end

model_pred.py为测试相关方法,原链接数据是在输入前转置+py文件中读取语句转置,我都去掉了,暂时没遇到问题。之前没注意说传进来都是double,所以把params都单独拿出来传递,还是double类型,直接类型转换了,就没改过去,可以参考原链接放在pyargs里。

def prediction(params, params_path, data_input):
    params = int(params) # 输入的数据类型转整数否则是浮点
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #和模型训练一致
    model = NetName(params, device) #数据类型需要
    model.load_state_dict(torch.load(params_path, map_location=torch.device('cpu')))
    model.eval()
    data_input= np.asarray(data_input)
    data_input= np.ascontiguousarray(data_input) #数据正常读取,之前没有这个一直报错
    data_input = torch.from_numpy(np.expand_dims(y_input[0],axis=0)).to(device)
    with torch.no_grad():
        out_model = model(data_input)
    out_model = out_model.cpu().numpy()
    return out_model 

module.py为训练模型

class NetName(nn.Module):
    def __init__(self,params,device):
        super(NeuralNetwork, self).__init__()
            self.params = params
            self._W1 = nn.Linear(params, params, bias=False)
        )
    def forward(self, data_input):
        x = self._W1 @ data_input
        return x

我是训练了两个模型一起用的,模型的调用、测试都和上面一样,在matlab中有数据可以直接调用ModelForPy.m批量测试,输出结果直接用就行。

遇到的问题

matlab调用py的环境

MATLAB配置Python环境-全网最清楚_怎么在matlab中配置python-CSDN博客
1、python和matlab的版本要适配
2、将python路径加入配置环境
3、pyenv(‘Version’,‘python路径\python.exe’)

数据传递

到matlab要这样转一下
    data_input= np.asarray(data_input)
    data_input= np.ascontiguousarray(data_input)

其他的应该就是单独在python上的问题,最大的问题就是数据类型对不上,要仔细检查一下。

  • 3
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值