感谢:
1 非常感谢宇超同学指出:额外为 output_tensor 开辟内存空间会覆盖推理结果。若不是他的提醒,我可能到明天都找不到运行结果出错的原因。
2 感谢 OPEN AI LAB 在 GitHub 上提供的代码框架,详见:https://github.com/OAID/Tengine
以下是 mnist.cpp 的代码
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* Copyright (c) 2018, Open AI Lab
* Author: chunyinglv@openailab.com
*/
#include <time.h>
#include <unistd.h>

#include <algorithm>
#include <cstddef>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>

#include "tengine_c_api.h"
#include "opencv2/imgproc/imgproc.hpp"
#include "opencv2/highgui/highgui.hpp"
#include "common_util.hpp"
// How many of the highest-scoring class indices main() prints.
constexpr int PRINT_TOP_NUM = 3;
// Per-channel mean subtracted from every pixel before scaling; main() passes
// it to get_input_data().
const float mean[3] = {127.5f, 127.5f, 127.5f};
using namespace TEngine;
void get_input_data(std::string image_file, float *input_data,
int img_h, int img_w, const float* mean, float scale)
{
cv::Mat sample = cv::imread(image_file, -1);
if (sample.empty())
{
std::cerr << "Failed to read image file " << image_file << ".\n";
return;
}
cv::Mat img;
if (sample.channels() == 4)
{
cv::cvtColor(sample, img, cv::COLOR_BGRA2BGR);
}
else if (sample.channels() == 1)
{
cv::cvtColor(sample, img, cv::COLOR_GRAY2BGR);
}
else
{
img=sample;
}
cv::resize(img, img, cv::Size(img_h, img_w));
img.convertTo(img, CV_32FC3);
float *img_data = (float *)img.data;
int hw = img_h * img_w;
for (int h = 0; h < img_h; h++)
{
for (int w = 0; w < img_w; w++)
{
for (int c = 0; c < 3; c++)
{
input_data[c * hw + h * img_w + w] = (*img_data - mean[c])*scale;
img_data++;
}
}
}
}
/*
 * Return the indices of the N largest values in v, ordered from largest to
 * smallest.
 *
 * v : scores to rank.
 * N : number of indices wanted; clamped to [0, v.size()], so a request larger
 *     than v yields every index and a non-positive request yields none.
 *
 * Equal values may appear in any order relative to each other.
 */
static inline std::vector<int> Argmax1(const std::vector<float> &v, int N)
{
    std::vector<std::pair<float, int>> pairs;
    pairs.reserve(v.size());
    for (size_t i = 0; i < v.size(); ++i)
        pairs.emplace_back(v[i], static_cast<int>(i));

    // std::partial_sort requires begin <= middle <= end; clamp N so an
    // oversized or negative request is not undefined behaviour.
    const size_t top = (N <= 0) ? 0 : std::min(static_cast<size_t>(N), pairs.size());

    std::partial_sort(pairs.begin(), pairs.begin() + static_cast<std::ptrdiff_t>(top), pairs.end(),
                      [](const std::pair<float, int>& a, const std::pair<float, int>& b)
                      { return a.first > b.first; });  // descending by value

    std::vector<int> result;
    result.reserve(top);
    for (size_t i = 0; i < top; ++i)
        result.push_back(pairs[i].second);
    return result;
}
int main(int argc, char * argv[])
{
std::string proto_name, mdl_name;
// init tengine
init_tengine();
//init_tengine_library();
std::string img_name = "00226.png";
// load model and create_graph
load_model("model", "tengine", "mnist.tmfile");
graph_t graph = create_runtime_graph("graph","model", NULL);//创建运行图
if (graph == nullptr)
{
std::cout << "Create Graph failed\n";
std::cout << "errno: " << get_tengine_errno() << "\n";
return 1;
}
//define inputdata and malloc memory and get input data
//设置输入图的大小
int img_h = 28;
int img_w = 28;
// set input shape
int img_size = img_h * img_w * 3;
int dims[] = { 1, 3, img_h, img_w };
float* input_data = (float*)malloc(sizeof(float) * img_size);
float* output_data = (float*)malloc(sizeof(float) * img_size);
float scale = 1.f / 255;
tensor_t input_tensor = get_graph_tensor(graph, "data");
//set tensor shape
set_tensor_shape(input_tensor, dims, 4);
//prerun graph
prerun_graph(graph);
//set tensor buffer
get_input_data(img_name, input_data, img_h, img_w, mean, scale);
set_tensor_buffer(input_tensor, input_data, img_size * 4);
//run graph
run_graph(graph,1);
//define and init output tensor
tensor_t output_tensor = get_graph_output_tensor(graph, 0, 0);
float* data = (float*)get_tensor_buffer(output_tensor);
float *end = data + 10;
std::vector<float> result(data, end);
std::vector<int> top_N = Argmax1(result, PRINT_TOP_NUM);
for (unsigned int i = 0; i < top_N.size(); i++)
{
int idx = top_N[i];
std::cout<<"Predict Result:" <<idx << "\n";
}
put_graph_tensor(output_tensor);
put_graph_tensor(input_tensor);
free(input_data);
//postrun graph
postrun_graph(graph);
//destroy runtime graph
destroy_runtime_graph(graph);
//remove model
/* Tengine -- deinitialization */
release_tengine_library();
return 0;
}
问