概述:
通过 OpenCV 提供的 KNN(K 近邻,ml::KNearest)模型,实现在图片中寻找数字并识别;
createLabel.cpp
从 digit.png 分割出数字块作为训练素材,再通过人工识别标定对应的数值,创建 TrainingData.yml(样本数据)和 LabelData.yml(标签数据)。
熟悉整个数字识别过程后可以发现,createLabel.cpp 其实还可以做出更多的改进,如增加手写体数字/畸形数字/其他字符的样本,或精简素材库;
trainTest.cpp
根据上个程序得到的 yml 文件训练模型,并对 dig.png 图片进行识别,在控制台中输出识别结果,同时创建一张与原图等大小的图片(dest.jpg),在其中标出识别的区域和识别结果;
笔者认为这步的核心在于 knn->findNearest() 函数;
环境:
配置好最新版本的 opencv 即可;
代码:
createLabel.cpp
#include "opencv2/opencv.hpp"
#include <opencv2/imgproc/types_c.h>
#include <iostream>
using namespace cv;
using namespace std;
int main(int argc, char** argv)
{
// 图片二值化,方便后续取轮廓
Mat thr, gray, con;
Mat src = imread("digit.png", 1); // 图片要和可执行文件放在同一目录,或者改成绝对路径
cvtColor(src, gray, COLOR_BGR2GRAY);
threshold(gray, thr, 200, 255, THRESH_BINARY_INV); //Threshold to find contour
thr.copyTo(con);
// 创建数据
vector< vector <Point> > contours; // Vector for storing contour
vector< Vec4i > hierarchy;
Mat sample;
Mat response_array;
findContours(con, contours, hierarchy, CV_RETR_CCOMP, CV_CHAIN_APPROX_SIMPLE); //Find contour
for (int i = 0; i< contours.size(); i = hierarchy[i][0]) // iterate through first hierarchy level contours
{
Rect r = boundingRect(contours[i]); //Find bounding rect for each contour
// rectangle() 会在图片上绘出最小外矩形
rectangle(src, Point(r.x, r.y), Point(r.x + r.width, r.y + r.height), Scalar(0, 0, 255), 2, 8, 0);
Mat ROI = thr(r); //Crop the image
Mat tmp1, tmp2;
resize(ROI, tmp1, Size(10, 10), 0, 0, INTER_LINEAR); //resize to 10X10
tmp1.convertTo(tmp2, CV_32FC1); //convert to float
sample.push_back(tmp2.reshape(1, 1)); // Store sample data
imshow("src", src); // 展示标出的方框,红框是接下来要标定的数字
waitKey(500); // 在 imshow() 最好有个 waitKey(), 否则图片容易刷新延时
int num;
cin >> num; // 输入你识别到的数字(红框内)
response_array.push_back(num); // Store label to a mat
rectangle(src, Point(r.x, r.y), Point(r.x + r.width, r.y + r.height), Scalar(0, 255, 0), 2, 8, 0);
}
// Store the data to file
Mat response, tmp;
tmp = response_array.reshape(1, 1); //make continuous
tmp.convertTo(response, CV_32FC1); // Convert to float
FileStorage Data("TrainingData.yml", FileStorage::WRITE); // Store the sample data in a file
Data << "data" << sample;
Data.release();
FileStorage Label("LabelData.yml", FileStorage::WRITE); // Store the label data in a file
Label << "label" << response;
Label.release();
cout << "Training and Label data created successfully....!! " << endl;
imshow("src", src);
cout << "success and end" << endl;
return 0;
}
trainTest.cpp
#include "opencv2/opencv.hpp"
#include<opencv2/ml/ml.hpp>
#include <opencv2/imgproc/types_c.h>
#include <iostream>
using namespace cv;
using namespace std;
int main(int argc, char** argv)
{
Mat thr, gray, con;
Mat src = imread("dig.png", 1); // 图片要和可执行文件放在同一目录,或者改成绝对路径
cvtColor(src, gray, CV_BGR2GRAY);
threshold(gray, thr, 200, 255, THRESH_BINARY_INV); // Threshold to create input
thr.copyTo(con);
// Read stored sample and label for training
Mat sample;
Mat response, tmp;
FileStorage Data("TrainingData.yml", FileStorage::READ); // Read traing data to a Mat
Data["data"] >> sample;
Data.release();
FileStorage Label("LabelData.yml", FileStorage::READ); // Read label data to a Mat
Label["label"] >> response;
Label.release();
Ptr<ml::KNearest> knn(ml::KNearest::create());
knn->train(sample, ml::ROW_SAMPLE,response); // Train with sample and responses
cout << "Training compleated.....!!" << endl;
vector< vector <Point> > contours; // Vector for storing contour
vector< Vec4i > hierarchy;
// 在图片中寻找轮廓
findContours(con, contours, hierarchy, CV_RETR_CCOMP, CV_CHAIN_APPROX_SIMPLE);
Mat dst(src.rows, src.cols, CV_8UC3, Scalar::all(0));
for (int i = 0; i< contours.size(); i = hierarchy[i][0]) // 逐个识别轮廓内数字
{
// cout << "begin analyze" << endl;
Rect r = boundingRect(contours[i]);
/*
* 这是笔者用来筛除测试案例中下划线的判断
* if (r.height > 3 * r.width || r.width > 3 * r.height) continue; // 跳过下划线
*/
rectangle(dst, Point(r.x, r.y), Point(r.x + r.width, r.y + r.height), Scalar(0, 0, 255), 2, 8, 0);
Mat ROI = thr(r);
Mat tmp1, tmp2;
resize(ROI, tmp1, Size(10, 10), 0, 0, INTER_LINEAR);
tmp1.convertTo(tmp2, CV_32FC1);
Mat response;
float p = knn->findNearest(tmp2.reshape(1, 1), 1, response); // 识别数字
// 在控制行中输出识别到的数字
cout << (int)p << endl;
char name[4];
sprintf(name, "%d", (int)p);
// 将识别到的数字标识到输出图片上
putText(dst, name, Point(r.x, r.y + r.height), 0, 1, Scalar(0, 255, 0), 2, 8);
}
imshow("src", src);
imshow("dst", dst);
imwrite("dest.jpg", dst);
return 0;
}