利用c++/c#实现mnist手写字符识别,包括模型训练、推理预测,依赖简单,开箱即用,全部通过代码实现,支持二次开发,以及gpu加速。
1.准备工作
- 环境:win10+vs2015/x64(≥vs2015均可以)+cuda(optional)
- 下载相关的依赖库yilecv.zip和mnist数据,下载地址见后
2.C++实现训练、推理
- 新建vc++空项目
- 新建源文件(用于训练或者推理)
训练代码
// train.cpp
#include <string>
#include <iostream>
#include "trainer.hpp"
int main(int argc, char** argv)
{
yilecv::Trainer trainer;
std::string model_file = "lenet_train_test.prototxt";
std::string solver_file = "lenet_solver.prototxt";
trainer.Init(model_file, solver_file, 0, "mnist_output");
trainer.SetMetricBlobName("accuracy", true);
trainer.Train();
std::cout << "trainer.BestMetricValue=" << trainer.BestMetricValue() << std::endl;
}
推理代码
// predict.cpp
#include <iostream>
#include <string>
#include <vector>
#include "predictor.hpp"
int main(int argc, char** argv)
{
yilecv::Predictor predictor;
std::string model_file = "lenet_deploy.prototxt";
predictor.Init(model_file, "mnist_output/best.bin", 0);
predictor.SetNormScaleCoeff({0.00390625,0.00390625,0.00390625});
std::vector<float> out = predictor.Predict("dataset/mnist_images/test/test_0_7.jpg");
std::vector<int> maxN = predictor.PredictMaxN(out, 1);
for (int i = 0; i < maxN.size(); ++i)
{
std::cout << maxN[i] << ":" << out[maxN[i]] << std::endl;
}
}
- 配置yilecv库
右键项目,选择属性=>c/c++=>常规=>附加包含目录,设置为yilecv库下的include目录
右键项目,选择属性=>链接器=>常规=>附加库目录,设置为yilecv库下的lib目录
右键项目,选择属性=>链接器=>输入=>附加依赖项,设置为yilecv.lib
- 拷贝yilecv/lib目录下的dll和网络模型相关文件到对应源文件目录,选择release/x64,运行即可
3.C#实现训练、推理
-
新建c#项目
-
添加YileCVSharp.cs文件到项目中
-
新建源文件,拷贝yilecv/lib目录下的dll和网络模型相关文件到对应bin目录下,选择release/x64,运行即可
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using YileCVSharp;
namespace YileCVSharpExample
{
class Program
{
static void Main(string[] args)
{
// 训练
//Trainer trainer = new Trainer();
//string model_file = "lenet_train_test.prototxt";
//string solver_file = "lenet_solver.prototxt";
//trainer.Init(model_file, solver_file, 0, "mnist_output");
//trainer.SetMetricBlobName("accuracy", true);
//trainer.Train();
//System.Console.WriteLine("trainer.BestMetricValue =" + ": " + trainer.BestMetricValue());
// 推理预测
Predictor predictor = new Predictor();
predictor.Init("lenet_deploy.prototxt", "mnist_output/best.bin", 0);
VectorFloat scale = new VectorFloat();
scale.Add(0.00390625f);
scale.Add(0.00390625f);
scale.Add(0.00390625f);
predictor.SetNormScaleCoeff(scale);
VectorFloat output = predictor.Predict("mnist_images/test/test_0_7.jpg");
VectorInt maxN = predictor.PredictMaxN(output, 1);
for (int i = 0; i < maxN.Capacity; ++i)
{
System.Console.WriteLine(maxN[i] + ": " + output[maxN[i]]);
}
System.Console.ReadKey();
}
}
}
4.下载地址
配置好的c#工程(包含mnist数据):https://download.csdn.net/download/u012594175/89165417
配置好的c++工程(包含mnis数据):https://download.csdn.net/download/u012594175/89165559
交流QQ群