python预测模型导出_mxnet训练/导出/加载模型并预测(python和C++)

本文展示了如何使用C++加载由MxNet训练的Python模型,并进行预测。通过创建`BufferFile`类读取模型文件,然后利用`MXPredCreate`、`MXPredSetInput`、`MXPredForward`和`MXPredGetOutput`等函数进行预测操作。
摘要由CSDN通过智能技术生成

#include #include #include #include #include #include

// file read buffer toolclass BufferFile

{

public:

std::string file_path_;

std::size_t length_ = 0;

std::unique_ptr buffer_;

explicit BufferFile(const std::string &file_path)

: file_path_(file_path)

{

std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);

if (!ifs)

{

std::cerr << "Can't open the file. Please check " << file_path << std::endl;

return;

}

ifs.seekg(0, std::ios::end);

length_ = static_cast<:size_t>(ifs.tellg());

ifs.seekg(0, std::ios::beg);

std::cout << file_path.c_str() << " ... " << length_ << " bytes\n";

buffer_.reset(new char[length_]);

ifs.read(buffer_.get(), length_);

ifs.close();

}

std::size_t GetLength()

{

return length_;

}

char* GetBuffer()

{

return buffer_.get();

}

};

int main(int argc, char* argv[])

{

// model file path std::string json_file = "model/simple_net-symbol.json";

std::string param_file = "model/simple_net-0020.params";

// read model file BufferFile json_data(json_file);

BufferFile param_data(param_file);

if (json_data.GetLength() == 0 || param_data.GetLength() == 0)

{

return EXIT_FAILURE;

}

// mxnet parameters int dev_type = 1; // 1: cpu, 2: gpu, we can change int dev_id = 0; // arbitrary. mx_uint num_input_nodes = 1; // 1 for feedforward const char *input_key[1] = { "data" };

const char **input_keys = input_key;

// define input data shape, notice this must be identical const mx_uint input_shape_indptr[2] = { 0, 2 }; // column dim is 2 const mx_uint input_shape_data[2] = { 3, 2 }; // 3 x 2 matrix input data shape

// global predicator handler PredictorHandle pred_hnd = nullptr;

// create predictor MXPredCreate(static_cast(json_data.GetBuffer()),

static_cast(param_data.GetBuffer()),

static_cast(param_data.GetLength()),

dev_type,

dev_id,

num_input_nodes,

input_keys,

input_shape_indptr,

input_shape_data,

&pred_hnd);

if (!pred_hnd)

{

std::cerr << "Failed to create predict handler" << std::endl;

return EXIT_FAILURE;

}

// prepare test data std::vector input_data{3, 5, 6, 10, 13, 7};

// set input data for mxnet MXPredSetInput(pred_hnd, "data", input_data.data(), input_data.size());

// do predict forward in mxnet model MXPredForward(pred_hnd);

mx_uint output_index = 0;

mx_uint *output_shape = nullptr;

mx_uint ouput_shape_len;

// get output result MXPredGetOutputShape(pred_hnd, output_index, &output_shape, &ouput_shape_len);

std::size_t size = 1;

for (mx_uint i = 0; i < ouput_shape_len; ++i) { size *= output_shape[i]; }

// construct output data from size std::vector output_data(size);

MXPredGetOutput(pred_hnd, output_index, &(output_data[0]), static_cast(size));

// release preditor MXPredFree(pred_hnd);

// print output data std::cout << "the result calculated by trained simple net: " << std::endl;

for (int i = 0; i < output_data.size(); i++)

std::cout << output_data[i] << std::endl;

return EXIT_SUCCESS;

}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值