#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <mxnet/c_predict_api.h>
// File read buffer tool: loads an entire file into an owned heap buffer.
class BufferFile
{
public:
    std::string file_path_;           // path the buffer was loaded from
    std::size_t length_ = 0;          // number of bytes loaded; 0 on failure
    std::unique_ptr<char[]> buffer_;  // file bytes + trailing '\0'; null on failure

    /// Reads the whole of `file_path` in binary mode.
    ///
    /// On failure (cannot open, cannot determine size, short read) a message is
    /// written to std::cerr, GetLength() returns 0 and GetBuffer() returns nullptr.
    /// On success the buffer holds GetLength() bytes followed by a '\0', so text
    /// content (e.g. a JSON symbol file) can safely be used as a C string.
    explicit BufferFile(const std::string &file_path)
        : file_path_(file_path)
    {
        std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
        if (!ifs)
        {
            std::cerr << "Can't open the file. Please check " << file_path << std::endl;
            return;
        }

        ifs.seekg(0, std::ios::end);
        const std::streamoff size = ifs.tellg();
        // tellg() reports failure as -1; never let that become a huge size_t.
        if (!ifs || size < 0)
        {
            std::cerr << "Can't determine the size of " << file_path << std::endl;
            return;
        }
        ifs.seekg(0, std::ios::beg);

        const std::size_t length = static_cast<std::size_t>(size);
        std::cout << file_path << " ... " << length << " bytes\n";

        // One extra byte so the contents are always '\0'-terminated.
        buffer_.reset(new char[length + 1]);
        ifs.read(buffer_.get(), static_cast<std::streamsize>(length));
        if (static_cast<std::size_t>(ifs.gcount()) != length)
        {
            std::cerr << "Failed to read the whole file " << file_path << std::endl;
            buffer_.reset();
            return;
        }
        buffer_[length] = '\0';
        length_ = length;
    }

    /// Number of bytes loaded (0 if loading failed).
    std::size_t GetLength() const
    {
        return length_;
    }

    /// Pointer to the loaded bytes, or nullptr if loading failed.
    char* GetBuffer() const
    {
        return buffer_.get();
    }
};
int main(int argc, char* argv[])
{
// model file path std::string json_file = "model/simple_net-symbol.json";
std::string param_file = "model/simple_net-0020.params";
// read model file BufferFile json_data(json_file);
BufferFile param_data(param_file);
if (json_data.GetLength() == 0 || param_data.GetLength() == 0)
{
return EXIT_FAILURE;
}
// mxnet parameters int dev_type = 1; // 1: cpu, 2: gpu, we can change int dev_id = 0; // arbitrary. mx_uint num_input_nodes = 1; // 1 for feedforward const char *input_key[1] = { "data" };
const char **input_keys = input_key;
// define input data shape, notice this must be identical const mx_uint input_shape_indptr[2] = { 0, 2 }; // column dim is 2 const mx_uint input_shape_data[2] = { 3, 2 }; // 3 x 2 matrix input data shape
// global predicator handler PredictorHandle pred_hnd = nullptr;
// create predictor MXPredCreate(static_cast(json_data.GetBuffer()),
static_cast(param_data.GetBuffer()),
static_cast(param_data.GetLength()),
dev_type,
dev_id,
num_input_nodes,
input_keys,
input_shape_indptr,
input_shape_data,
&pred_hnd);
if (!pred_hnd)
{
std::cerr << "Failed to create predict handler" << std::endl;
return EXIT_FAILURE;
}
// prepare test data std::vector input_data{3, 5, 6, 10, 13, 7};
// set input data for mxnet MXPredSetInput(pred_hnd, "data", input_data.data(), input_data.size());
// do predict forward in mxnet model MXPredForward(pred_hnd);
mx_uint output_index = 0;
mx_uint *output_shape = nullptr;
mx_uint ouput_shape_len;
// get output result MXPredGetOutputShape(pred_hnd, output_index, &output_shape, &ouput_shape_len);
std::size_t size = 1;
for (mx_uint i = 0; i < ouput_shape_len; ++i) { size *= output_shape[i]; }
// construct output data from size std::vector output_data(size);
MXPredGetOutput(pred_hnd, output_index, &(output_data[0]), static_cast(size));
// release preditor MXPredFree(pred_hnd);
// print output data std::cout << "the result calculated by trained simple net: " << std::endl;
for (int i = 0; i < output_data.size(); i++)
std::cout << output_data[i] << std::endl;
return EXIT_SUCCESS;
}