题目要求:在上一章《OpenCV----简单目标提取和分割》中,我们使用OpenCV的连通域方法获取了目标的面积和轮廓信息;本章将对这些特征进行整合,使用OpenCV中的ml库(machine learning)训练一个目标分类器,对给定的一张输入图片,预测图像中各目标的类别。
最新代码库已更新:opencv4_cpp
分析:
1)了解常用的分类模型,参见opencv ml API(OpenCV namespace ml),首先建立一个简单的SVM分类器,然后设计其他的分类器,如朴素贝叶斯模型等;
2)设计命令行设置输入图片和模型选择等,支持默认模型选择;
3)考虑使用模板实现分类模型的通用选择;
4)支持模型参数手动设定;
5)将代码设计成便于管理、可读性强的工程,设计通用的工程模板;
-
文件工程目录
·bin ------------------生成可执行文件目录
·build ----------------文件编译的中间结果
·data -----------------训练测试数据(图像)
·include -------------头文件目录和inline文件(用于模板方法实现)
·lib --------------------编译生成的库文件(utils、mwindow)
·model --------------不同模型训练生成的xml文件
·src -------------------源文件和源文件编译命令CMakeLists.txt
·CMakeLists.txt —外层文件编译命令
-
outer CMakeLists.txt
# Outer CMakeLists.txt: global build settings, output directories, then src/.
# 3.10 instead of 3.0: CMake >= 4.0 rejects a minimum version below 3.5.
cmake_minimum_required (VERSION 3.10)
PROJECT(ch6)
# Default to an optimized build, but let -DCMAKE_BUILD_TYPE=... override it.
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()
set(CMAKE_CXX_FLAGS "-std=c++17 -Wall")
set(CMAKE_CXX_FLAGS_RELEASE "-std=c++17 -O3 -fopenmp -pthread")
# Optional Conan integration (only when conan generated its build info).
IF(EXISTS ${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()
ENDIF()
# Executables go to bin/, libraries (utils, mwindow) to lib/.
set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/bin)
set(LIBRARY_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/lib)
include_directories(${PROJECT_SOURCE_DIR}/include)
link_directories(${PROJECT_SOURCE_DIR}/lib)
add_subdirectory(${PROJECT_SOURCE_DIR}/src)
- inner CMakeLists.txt
# Inner (src/) CMakeLists.txt: builds the utils/mwindow libraries and main.
# 3.10 instead of 3.0: CMake >= 4.0 rejects a minimum version below 3.5.
cmake_minimum_required (VERSION 3.10)
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()
set(CMAKE_CXX_FLAGS "-std=c++17 -Wall")
set(CMAKE_CXX_FLAGS_RELEASE "-std=c++17 -O3 -fopenmp -pthread")
# Project headers (mwindow.hpp, utils.hpp, utils.inl).
# Was ${PROJECT_SOURSE_DIR} (typo): an undefined variable expands to "", so
# the old line silently added "/include" instead of the project's include dir.
include_directories(${PROJECT_SOURCE_DIR}/include)
# Requires OpenCV
find_package(OpenCV REQUIRED)
message("OpenCV version : ${OpenCV_VERSION}")
include_directories(${OpenCV_INCLUDE_DIRS})
link_directories(${OpenCV_LIB_DIR})
add_library(mwindow mwindow.cpp)
add_library(utils utils.cpp)
# Declare the real dependencies: utils draws into MWindow, both use OpenCV.
target_link_libraries(mwindow ${OpenCV_LIBS})
target_link_libraries(utils mwindow ${OpenCV_LIBS})
add_executable(main main.cpp)
target_link_libraries(main ${OpenCV_LIBS} utils mwindow -lopencv_ml)
- 结果演示
./bin/main data/test.pgm data/pattern.pgm bayes
./bin/main data/test.pgm data/pattern.pgm svm   (mode参数可省略,默认为svm)
./bin/main data/test.pgm data/pattern.pgm boost
- 代码示例
1)更新之前的多窗口类MWindow Class
/*
@File :mwindow.hpp
@Description: :
@Date :2021/12/25 09:23:14
@Author :xieyin
@version :1.0
*/
#pragma once
#include<iostream>
#include<string>
#include<vector>
using namespace std;
#include<opencv2/core.hpp>
#include<opencv2/highgui.hpp>
using namespace cv;
// MWindow: one OpenCV window whose canvas is split into a rows x cols grid of
// titled sub-images. Titles and images are stored in insertion order and are
// drawn onto the canvas by render().
class MWindow{
public:
// constructor: opens a window named windowTitle (namedWindow flags = flags) and
// allocates a height x width 3-channel canvas divided into rows x cols cells.
MWindow(string windowTitle, int rows, int cols, int height=700, int width=1200, int flags=WINDOW_AUTOSIZE);
// add an image under `title`; if the title already exists its image is replaced.
// When render is true the canvas is redrawn immediately.
// Returns an index into the stored image list.
int addImage(string title, Mat img, bool render = false);
// remove the title/image pair stored at index pos
void removeImage(int pos);
// redraw: scale each stored image to fit its grid cell, draw its title and show
// the canvas; images beyond rows * cols are not drawn
void render();
private:
string mWindowTitle;        // window name used by namedWindow/imshow
int mRows;                  // grid rows
int mCols;                  // grid columns
Mat mCanvas;                // composite image shown in the window
vector<string> mSubTitles;  // sub-image titles, parallel to mSubImages
vector<Mat> mSubImages;     // sub-images in insertion order
};
/*
@File :mwindow.cpp
@Description: :
@Date :2021/12/25 09:23:22
@Author :xieyin
@version :1.0
*/
#include<iostream>
#include<string>
#include<vector>
using namespace std;
#include<opencv2/core.hpp>
#include<opencv2/highgui.hpp>
#include<opencv2/opencv.hpp>
#include<opencv2/imgproc.hpp>
using namespace cv;
#include"mwindow.hpp"
MWindow::MWindow(string windowTitle, int rows, int cols, int height, int width, int flags):mWindowTitle(windowTitle), mRows(rows), mCols(cols){
    /*
    @description : MWindow constructor, opens the window and creates the canvas
    @param :
        windowTitle : whole window title
        rows : sub window rows
        cols : sub window cols
        height : canvas height in pixels
        width : canvas width in pixels
        flags : namedWindow flags (eg, WINDOW_AUTOSIZE)
    @Returns :
    */
    namedWindow(mWindowTitle, flags);
    // Fill the canvas with the same background color render() uses; a bare
    // Mat(h, w, type) is uninitialised and would display garbage until render().
    mCanvas = Mat(height, width, CV_8UC3, Scalar(20, 20, 20));
    imshow(mWindowTitle, mCanvas);
}
int MWindow::addImage(string title, Mat img, bool render){
    /*
    @description : add title and image into canvas; an existing title gets its image replaced
    @param :
        title : sub image title
        img : image to be added
        render : render(flag) whether to redraw the canvas right away
    @Returns :
        index : slot of the image in the stored list (also for a replaced title)
    */
    int index = -1;
    for(size_t i = 0; i < mSubTitles.size(); i++){
        if(mSubTitles[i] == title){
            index = static_cast<int>(i);
            break;
        }
    }
    if(index == -1){
        mSubTitles.push_back(title);
        mSubImages.push_back(img);
        index = static_cast<int>(mSubImages.size()) - 1;
    }else{
        // Same title: replace the image in place and keep its slot. (The old code
        // returned size()-1 here, i.e. the last slot instead of this one.)
        mSubImages[index] = img;
    }
    if(render){
        // The `render` parameter shadows the member function, hence the qualification.
        MWindow::render();
    }
    return index;
}
void MWindow::removeImage(int pos){
    /*
    @description : remove the title/image pair at index pos; out-of-range indices are ignored
    @param :
        pos : sub image index in the stored list
    @Returns :
        None
    */
    // vector::erase with an invalid iterator is undefined behavior, so reject
    // indices outside [0, size).
    if(pos < 0 || pos >= static_cast<int>(mSubImages.size())){
        return;
    }
    mSubTitles.erase(mSubTitles.begin() + pos);
    mSubImages.erase(mSubImages.begin() + pos);
}
void MWindow::render(){
/*
@description : fill title and image into canvas in suitable way
@param :
None
@Returns :
None
*/
mCanvas.setTo(Scalar(20, 20, 20));
// get sub canvas size
int cellH = mCanvas.rows / mRows;
int cellW = mCanvas.cols / mCols;
// set total number of images to load
int n = mSubImages.size();
int numImgs = n > mRows * mCols ? mRows * mCols : n;
for(int i = 0; i < numImgs; i++){
// get title
string title = mSubTitles[i];
// get sub canvas top left location
int cellX = (cellW) * ((i) % mCols);
int cellY = (cellH) * floor( (i) / (float) mCols);
Rect mask(cellX, cellY, cellW, cellH);
// set subcanvas size
rectangle(mCanvas, Rect(cellX, cellY, cellW, cellH), Scalar(200, 200, 200), 1);
Mat cell(mCanvas, mask);
Mat imgResz;
// get cell aspect
double cellAspect = (double) cellW / (double) cellH;
// get image
Mat img = mSubImages[i];
// get image aspect
double imgAspect = (double) img.cols / (double) img.cols;
double wAspect = (double) cellW / (double) img.cols;
double hAspect = (double) cellH / (double) img.rows;
// get suitable aspect and resize image
double aspect = cellAspect < imgAspect ? wAspect : hAspect;
resize(img, imgResz, Size(0, 0), aspect, aspect);
// if gray image, convert to BGR
if(imgResz.channels() == 1){
cvtColor(imgResz, imgResz, COLOR_GRAY2BGR);
}
Mat subCell(mCanvas, Rect(cellX, cellY, imgResz.cols, imgResz.rows));
imgResz.copyTo(subCell);
putText(cell, title, Point(20, 20), FONT_HERSHEY_SIMPLEX, 0.6, Scalar(255, 0, 0));
}
// show total canvas
imshow(mWindowTitle, mCanvas);
}
2)设计通用辅助函数utils.hpp、utils.cpp和utils.inl
/*
@File :utils.hpp
@Description: :
@Date :2021/12/25 09:23:48
@Author :xieyin
@version :1.0
*/
#pragma once
#include<string>
#include<cmath>
#include<memory>
using namespace std;
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/ml.hpp>
using namespace cv;
using namespace cv::ml;
#include"mwindow.hpp"
// Shared multi-image display window; defined in main.cpp and created there.
extern shared_ptr<MWindow> myWin;
// Return a random BGR color; each channel is one byte of a value drawn from rng.
Scalar randColor(RNG& rng);
// Estimate the background light pattern of img with a large box blur
// (kernel side = img.cols / 3).
Mat calLigthPattern(Mat img);
// Remove the light pattern from img. methodLight 0 (default): difference
// (pattern - img); 1: division, 255 * (1 - img / pattern) as 8-bit.
Mat removeLight(Mat img, Mat pattern, int methodLight=0);
// Wrapper around cv::connectedComponents: returns a 3-channel image with every
// component in a random color, or an empty Mat when no object is found.
Mat connectedComponents(Mat img_thr);
// Wrapper around cv::connectedComponentsWithStats: like connectedComponents,
// and additionally writes each component's area at its centroid.
Mat connectedComponentsWithStats(Mat img_thr);
// Wrapper around cv::findContours (external contours only): returns a 3-channel
// image with contours drawn in random colors and a "*" at each centroid, or an
// empty Mat when no contour is found.
Mat findContours(Mat img_thr);
// Read every image under filePath, preprocess it and extract (area, aspect)
// features labeled `label`. Samples from the first numTest images go to
// testData/testResponses, the rest to trainingData/trainResponses.
// Returns false when no file is found under filePath.
bool readFolderAndExtractFeatures(string filePath, int label, int numTest,
vector<float>& trainingData, vector<int>& trainResponses, vector<float>& testData, vector<int>& testResponses);
// Scatter-plot the training features (x: area, y: aspect) colored by label and
// add the figure to myWin as "Fig"; the error rate is printed when error != NULL.
void plotData(Mat trainingDataMat, Mat trainResponsesMat, string mode="svm", float* error=NULL);
// (Re)create svm and set its parameters (C_SVC, CHI2 kernel, max 100 iterations).
void defineSVM(Ptr<SVM>& svm);
// Train a model of type T (e.g. SVM, NormalBayesClassifier, Boost) with default
// parameters on data/nut, data/ring, data/screw (labels 0, 1, 2), save it to
// model/<mode>.xml, report the test error rate and plot the data. See utils.inl.
template<typename T>
void trainAndTest(string mode="svm");
// Like trainAndTest<SVM>, but with the parameters from defineSVM; saves model/svm.xml.
void trainSVM();
// Load model/<mode>.xml, classify each feature row and write the class name on
// imgOut at (posLeft[i], posTop[i]).
template<typename T>
void predict(vector<vector<float>> features, vector<int> posLeft, vector<int> posTop, string mode, Mat& imgOut);
// Gray conversion (if needed), median blur, light removal using
// data/pattern.pgm, then binary threshold at 30.
Mat preProcess(Mat img);
// For every external contour whose filled area exceeds 500 px, return
// {area, aspect}. posLeft/posTop (optional) receive the center of the object's
// min-area rectangle (despite the names, not its top-left corner).
vector<vector<float>> extractFeatures(Mat img, vector<int>* posLeft=NULL, vector<int>* posTop=NULL);
#include"utils.inl"
/*
@File :utils.cpp
@Description: :
@Date :2021/12/25 09:23:38
@Author :xieyin
@version :1.0
*/
#include<string>
#include<cmath>
#include<memory>
#include<iostream>
#include<vector>
using namespace std;
#include<opencv2/core.hpp>
#include<opencv2/highgui.hpp>
#include<opencv2/imgproc.hpp>
#include<opencv2/opencv.hpp>
#include<opencv2/ml.hpp>
using namespace cv;
using namespace cv::ml;
#include"utils.hpp"
#include"mwindow.hpp"
Scalar randColor(RNG& rng){
    /*
    @description : draw a random BGR color from the generator
    @param :
        rng : random number generator object (advanced by one draw)
    @Returns :
        Scalar() : BGR color, channels taken from the low three bytes of the draw
    */
    const unsigned bits = static_cast<unsigned>(rng);
    const unsigned blue = bits & 0xFFu;
    const unsigned green = (bits >> 8) & 0xFFu;
    const unsigned red = (bits >> 16) & 0xFFu;
    return Scalar(blue, green, red);
}
Mat calLigthPattern(Mat img){
    /*
    @description : approximate the background illumination of an image
    @param :
        img : source BGR image or Gray image
    @Returns :
        pattern : box-blurred image used as the light pattern
    */
    // One kernel side, derived from the image width, is used for both axes.
    const int side = img.cols / 3;
    Mat blurred;
    blur(img, blurred, Size(side, side));
    return blurred;
}
Mat removeLight(Mat img, Mat pattern, int methodLight){
    /*
    @description : remove the background light pattern from an image
    @param :
        img : source BGR/Gray image
        pattern : pattern BGR/Gray image
        methodLight : 1 = division, anything else = difference
    @Returns :
        light removed image; 8-bit for the division method, same depth as
        the inputs (saturated) for the difference method
    */
    if(methodLight != 1){
        // difference: background minus image
        return pattern - img;
    }
    // division is done in 32-bit float, then scaled back to 8-bit with clipping
    Mat imgF;
    Mat patternF;
    img.convertTo(imgF, CV_32F);
    pattern.convertTo(patternF, CV_32F);
    Mat ratio = 1.0 - (imgF / patternF);
    Mat result;
    ratio.convertTo(result, CV_8U, 255);
    return result;
}
Mat connectedComponents(Mat img_thr){
    /*
    @description : label connected components and paint each with a random color
    @param :
        img_thr : threshold (binary) image
    @Returns :
        CV_8UC3 color-coded label image, or an empty Mat when no object exists
    */
    Mat labels;
    const int numLabels = cv::connectedComponents(img_thr, labels);
    // label 0 is the background, so fewer than two labels means no object
    if(numLabels <= 1){
        cout << "no object is detected. " << endl;
        return Mat();
    }
    Mat colored = Mat::zeros(img_thr.rows, img_thr.cols, CV_8UC3);
    RNG rng(0xFFFFFFFF);
    for(int label = 1; label < numLabels; ++label){
        Mat objMask = labels == label;
        colored.setTo(randColor(rng), objMask);
    }
    return colored;
}
Mat connectedComponentsWithStats(Mat img_thr){
    /*
    @description : label connected components, paint each with a random color and
                   write its pixel area at its centroid
    @param :
        img_thr : threshold (binary) image
    @Returns :
        CV_8UC3 annotated label image, or an empty Mat when no object exists
    */
    Mat labels, stats, centroids;
    const int numLabels = cv::connectedComponentsWithStats(img_thr, labels, stats, centroids);
    // label 0 is the background, so fewer than two labels means no object
    if(numLabels <= 1){
        cout << "no object is detected. " << endl;
        return Mat();
    }
    Mat colored = Mat::zeros(img_thr.rows, img_thr.cols, CV_8UC3);
    RNG rng(0xFFFFFFFF);
    for(int label = 1; label < numLabels; ++label){
        Mat objMask = labels == label;
        colored.setTo(randColor(rng), objMask);
        const string areaText = "area: " + to_string(stats.at<int>(label, CC_STAT_AREA));
        putText(colored, areaText, centroids.at<Point2d>(label), FONT_HERSHEY_SIMPLEX, 0.3, Scalar(0, 255, 0));
    }
    return colored;
}
Mat findContours(Mat img_thr){
    /*
    @description : find external contours, draw them in random colors and mark each
                   contour's centroid with "*"
    @param :
        img_thr : threshold (binary) image
    @Returns :
        res : CV_8UC3 drawing, or an empty Mat when no contour is found
    */
    vector<vector<Point>> contours;
    cv::findContours(img_thr, contours, RETR_EXTERNAL, CHAIN_APPROX_SIMPLE);
    if(contours.empty()){
        cout << "no contours are found ." << endl;
        return Mat();
    }
    RNG rng(0xFFFFFFFF);
    Mat res = Mat::zeros(img_thr.rows, img_thr.cols, CV_8UC3);
    for(size_t i = 0; i < contours.size(); i++){
        drawContours(res, contours, static_cast<int>(i), randColor(rng));
        Moments mu = moments(contours[i], false);
        // Degenerate contours (a single point or a straight line) have zero area,
        // so m00 == 0 and the centroid is undefined (NaN -> int is undefined
        // behavior); draw no marker for them.
        if(mu.m00 == 0.0){
            continue;
        }
        Point center(static_cast<int>(mu.m10 / mu.m00), static_cast<int>(mu.m01 / mu.m00));
        putText(res, "*", center, FONT_HERSHEY_SIMPLEX, 0.4, Scalar(255, 0, 255), 1);
    }
    return res;
}
// helper function for readFolderAndExtractFeatures, preprocess image to binary image
Mat preProcess(Mat img){
    /*
    @description : denoise, remove the background light and binarize an image
    @param :
        img : image to process (BGR or gray)
    @Returns :
        imgOut : binary image (threshold 30 on the light-removed image)
    */
    if(img.channels() == 3){
        cvtColor(img, img, COLOR_BGR2GRAY);
    }
    // The light pattern file does not change during a run, so read it once
    // instead of once per training image.
    static const Mat lightPat = imread("data/pattern.pgm", 0);
    if(lightPat.empty()){
        // Without this check removeLight() fails with an obscure size-mismatch error.
        CV_Error(Error::StsObjectNotFound, "preProcess: can not read light pattern data/pattern.pgm");
    }
    Mat imgNoise;
    medianBlur(img, imgNoise, 3);
    Mat imgLight = removeLight(imgNoise, lightPat);
    Mat imgOut;
    threshold(imgLight, imgOut, 30, 255, THRESH_BINARY);
    return imgOut;
}
// helper function for trainAndTest, readFolderAndExtractFeatures
bool readFolderAndExtractFeatures(string filePath, int label, int numTest,
    vector<float> &trainingData, vector<int> &trainResponses, vector<float> &testData, vector<int> &testResponses){
    /*
    @description : read every image under filePath and extract (area, aspect) features
    @param :
        filePath : image folder (searched recursively)
        label : class label assigned to every sample from this folder
        numTest : number of images (not samples) whose objects go to the test split
        trainingData : trainingData feature: area, aspect (flattened pairs)
        trainResponses : trainingData label
        testData : testData feature: area, aspect (flattened pairs)
        testResponses : testData label
    @Returns :
        false when no file is found, true otherwise
        (ref return) : trainingData, trainResponses, testData, testResponses
    */
    vector<String> files;
    glob(filePath, files, true);
    if(files.empty()){
        return false;
    }
    int imgIdx = 0;
    for(const String& file : files){
        Mat frame = imread(file);
        if(frame.empty()){
            // glob() returns every file; skip anything that is not a readable
            // image instead of crashing inside preProcess().
            cerr << "skip unreadable image: " << file << endl;
            continue;
        }
        Mat pre = preProcess(frame);
        // an image may contain several objects, each giving one feature pair
        vector<vector<float>> features = extractFeatures(pre);
        // all objects of the first numTest images form the test split
        const bool isTest = imgIdx < numTest;
        vector<float>& data = isTest ? testData : trainingData;
        vector<int>& responses = isTest ? testResponses : trainResponses;
        for(const vector<float>& feature : features){
            data.push_back(feature[0]);
            data.push_back(feature[1]);
            responses.push_back(label);
        }
        imgIdx++;
    }
    return true;
}
// helper function for trainAndTest, plot data and error
void plotData(Mat trainingDataMat, Mat trainResponsesMat, string mode, float* error){
    /*
    @description : scatter-plot the training features (x: area, y: aspect), colored by
                   label (0 blue, 1 green, 2 red), and optionally print the error rate
    @param :
        trainingDataMat : CV_32F matrix [N, 2], columns (area, aspect)
        trainResponsesMat : CV_32S matrix [N, 1], class label per row
        mode : model name shown with the error
        error : optional error rate in percent; no text is drawn when NULL
    @Returns :
        None (the figure is added to the shared window as "Fig")
    */
    Mat fig = Mat::zeros(512, 512, CV_8UC3);
    const int numSamples = trainingDataMat.rows;
    if(numSamples > 0){
        // per-feature min/max, seeded from the first sample
        float areaMin = trainingDataMat.at<float>(0, 0);
        float areaMax = areaMin;
        float asMin = trainingDataMat.at<float>(0, 1);
        float asMax = asMin;
        for(int i = 1; i < numSamples; i++){
            float area = trainingDataMat.at<float>(i, 0);
            float aspect = trainingDataMat.at<float>(i, 1);
            if(area > areaMax){
                areaMax = area;
            }
            if(area < areaMin){
                areaMin = area;
            }
            if(aspect > asMax){
                asMax = aspect;
            }
            // the old code compared asMin against area here
            if(aspect < asMin){
                asMin = aspect;
            }
        }
        // avoid 0/0 (NaN -> int is undefined) when all samples share a value
        const float areaRange = areaMax > areaMin ? areaMax - areaMin : 1.0f;
        const float asRange = asMax > asMin ? asMax - asMin : 1.0f;
        for(int i = 0; i < numSamples; i++){
            float area = trainingDataMat.at<float>(i, 0);
            float aspect = trainingDataMat.at<float>(i, 1);
            // min-max norm [0~1] * 420 pixel
            int x = (int)(420.0f * ((area - areaMin) / areaRange));
            int y = (int)(420.0f * ((aspect - asMin) / asRange));
            int label = trainResponsesMat.at<int>(i);
            Scalar color;
            if(label == 0){
                color = Scalar(255, 0, 0);
            }else if(label == 1){
                color = Scalar(0, 255, 0);
            }else if(label == 2){
                color = Scalar(0, 0, 255);
            }
            // offset by (80, 80) to keep the points away from the border
            circle(fig, Point(x + 80, y + 80), 3, color, FILLED, LINE_8);
        }
    }
    if(error != nullptr){
        stringstream ss;
        ss << mode << " error: " << *error << " %";
        putText(fig, ss.str(), Point(20, 512-40), FONT_HERSHEY_SIMPLEX, 0.75, Scalar(200, 200, 200), 1, LINE_AA);
    }
    // the shared window only exists once main() has created it
    if(myWin){
        myWin->addImage("Fig", fig);
    }
}
void defineSVM(Ptr<SVM>& svm){
    /*
    @description : replace svm with a fresh SVM and configure its parameters
    @param :
        svm : svm model handle to (re)initialise
    @Returns :
        (ref return) : svm with parameters
    */
    svm = SVM::create();
    SVM& cfg = *svm;
    cfg.setType(SVM::C_SVC);
    cfg.setNu(0.05);     // only read by the NU_* / ONE_CLASS types
    cfg.setKernel(SVM::CHI2);
    cfg.setDegree(1.0);  // only read by the POLY kernel
    cfg.setGamma(2.0);
    const TermCriteria stopRule(TermCriteria::MAX_ITER, 100, 1e-6);
    cfg.setTermCriteria(stopRule);
}
void trainSVM(){
/*
@description : train a svm model and test its error rate
@param :
mode : machine learning mode
@Returns :
None
*/
vector<float> trainingData;
vector<int> trainResponses;
vector<float> testData;
vector<int> testResponses;
int numTest = 20;
string nutPath = "data/nut";
string ringPath = "data/ring";
string screwPath = "data/screw";
// read data path and extract feature
readFolderAndExtractFeatures(nutPath, 0, numTest, trainingData, trainResponses, testData, testResponses);
readFolderAndExtractFeatures(ringPath, 1, numTest, trainingData, trainResponses, testData, testResponses);
readFolderAndExtractFeatures(screwPath, 2, numTest, trainingData, trainResponses, testData, testResponses);
// cout << "Num of train samples: " << trainingData.size() << endl;
// cout << "Num of test samples: " << testData.size() << endl;
Mat trainingDataMat(trainingData.size() / 2, 2, CV_32FC1, &trainingData[0]);
Mat trainResponsesMat(trainResponses.size(), 1, CV_32SC1, &trainResponses[0]);
Mat testDataMat(testData.size() / 2, 2, CV_32FC1, &testData[0]);
Mat testResponsesMat(testResponses.size(), 1, CV_32SC1, &testResponses[0]);
// set row sample
Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, trainResponsesMat);
// select model
Ptr<SVM> model = SVM::create();
defineSVM(model);
model->train(tData);
model->save("model/svm.xml");
if(testResponses.size() > 0){
Mat testPredict;
// predict
model->predict(testDataMat, testPredict);
testPredict.convertTo(testPredict, CV_32SC1);
Mat errMat = testPredict != testResponsesMat;
float error = 100.0f * countNonZero(errMat) / testResponses.size();
cout << "svm" << " Error rate: " << error << "\%" << endl;
plotData(trainingDataMat, trainResponsesMat, "svm", &error);
}
else{
plotData(trainingDataMat, trainResponsesMat, "svm");
}
}
vector<vector<float>> extractFeatures(Mat img, vector<int>* posLeft, vector<int>* posTop){
    /*
    @description : compute (area, aspect) for every external contour whose filled area is
                   larger than 500 pixels; aspect = long side / short side of its minAreaRect
    @param :
        img : binary image (e.g. output of preProcess)
        posLeft : optional output, x of each kept object's minAreaRect center
        posTop : optional output, y of each kept object's minAreaRect center
    @Returns :
        features : one {area, aspect} row per kept object, same order as posLeft/posTop
    */
    vector<vector<float>> features;
    vector<vector<Point>> contours;
    vector<Vec4i> hierarchy;
    Mat temp = img.clone();
    cv::findContours(temp, contours, hierarchy, RETR_EXTERNAL, CHAIN_APPROX_SIMPLE);
    for(size_t i = 0; i < contours.size(); i++){
        // Fill the contour with 1s so that the mask sum is the object's pixel area.
        Mat mask = Mat::zeros(img.rows, img.cols, CV_8UC1);
        drawContours(mask, contours, static_cast<int>(i), Scalar(1), FILLED, LINE_8, hierarchy, 1);
        float area = static_cast<float>(sum(mask)[0]);
        // small blobs are treated as noise
        if(area <= 500){
            continue;
        }
        RotatedRect r = minAreaRect(contours[i]);
        float w = r.size.width;
        float h = r.size.height;
        float aspect = w < h ? h / w : w / h;
        features.push_back({area, aspect});
        if(posLeft != nullptr){
            posLeft->push_back((int)r.center.x);
        }
        if(posTop != nullptr){
            posTop->push_back((int)r.center.y);
        }
        // Only visualise when the shared window exists (main() creates it);
        // otherwise myWin is a null shared_ptr.
        if(myWin){
            myWin->addImage("Extracted Feature", mask * 255);
            myWin->render();
            waitKey(10);
        }
    }
    return features;
}
/*
@File :utils.inl
@Description: :
@Date :2021/12/25 20:29:26
@Author :xieyin
@version :1.0
*/
template<typename T>
void trainAndTest(string mode){
    /*
    @description : train an OpenCV ml model of type T (e.g. SVM, NormalBayesClassifier,
                   Boost) with its default parameters, save it to model/<mode>.xml and
                   report its test error rate
    @param :
        mode : machine learning mode, used in the model file name and in messages
    @Returns :
        None
    */
    vector<float> trainingData;
    vector<int> trainResponses;
    vector<float> testData;
    vector<int> testResponses;
    const int numTest = 20;
    const string nutPath = "data/nut";
    const string ringPath = "data/ring";
    const string screwPath = "data/screw";
    // read data path and extract feature; labels 0 nut, 1 ring, 2 screw
    readFolderAndExtractFeatures(nutPath, 0, numTest, trainingData, trainResponses, testData, testResponses);
    readFolderAndExtractFeatures(ringPath, 1, numTest, trainingData, trainResponses, testData, testResponses);
    readFolderAndExtractFeatures(screwPath, 2, numTest, trainingData, trainResponses, testData, testResponses);
    if(trainResponses.empty()){
        cout << mode << ": no training samples found, skip training." << endl;
        return;
    }
    // Each sample is an (area, aspect) pair -> size()/2 rows of 2 columns.
    // .data() instead of &v[0]: indexing an empty vector (e.g. no test images) is UB.
    Mat trainingDataMat(static_cast<int>(trainingData.size() / 2), 2, CV_32FC1, trainingData.data());
    Mat trainResponsesMat(static_cast<int>(trainResponses.size()), 1, CV_32SC1, trainResponses.data());
    Mat testDataMat(static_cast<int>(testData.size() / 2), 2, CV_32FC1, testData.data());
    Mat testResponsesMat(static_cast<int>(testResponses.size()), 1, CV_32SC1, testResponses.data());
    // set row sample
    Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, trainResponsesMat);
    // select model
    Ptr<T> model = T::create();
    model->train(tData);
    model->save("model/" + mode + ".xml");
    if(!testResponses.empty()){
        Mat testPredict;
        model->predict(testDataMat, testPredict);
        testPredict.convertTo(testPredict, CV_32SC1);
        Mat errMat = testPredict != testResponsesMat;
        float error = 100.0f * countNonZero(errMat) / testResponses.size();
        cout << mode << " Error rate: " << error << "%" << endl;
        plotData(trainingDataMat, trainResponsesMat, mode, &error);
    }
    else{
        plotData(trainingDataMat, trainResponsesMat, mode);
    }
}
template<typename T>
void predict(vector<vector<float>> features, vector<int> posLeft, vector<int> posTop, string mode, Mat& imgOut){
    /*
    @description : classify each feature with the model saved in model/<mode>.xml and write
                   the class name (NUT / RING / SCREW) on imgOut at the object position
    @param :
        features : {area, aspect} rows extracted from the image
        posLeft : x position of each object (same order as features)
        posTop : y position of each object (same order as features)
        mode : machine learning mode, selects model/<mode>.xml
        imgOut : the img with text output
    @Returns :
        (ref return) : imgOut
    */
    if(features.empty()){
        return;
    }
    // Load the model once; it is the same for every feature.
    const string modelPath = "model/" + mode + ".xml";
    Ptr<T> model = Algorithm::load<T>(modelPath);
    if(model.empty()){
        cout << "can not load model: " << modelPath << endl;
        return;
    }
    for(size_t i = 0; i < features.size(); i++){
        Mat predDataMat(1, 2, CV_32FC1, features[i].data());
        float result = model->predict(predDataMat);
        cout << result << endl;
        // without a position there is nowhere to draw the label
        if(i >= posLeft.size() || i >= posTop.size()){
            continue;
        }
        // colors match plotData: 0 blue, 1 green, 2 red (SCREW used to be green like RING)
        string text;
        Scalar color;
        switch(cvRound(result)){
            case 0:
                color = Scalar(255, 0, 0);
                text = "NUT";
                break;
            case 1:
                color = Scalar(0, 255, 0);
                text = "RING";
                break;
            case 2:
                color = Scalar(0, 0, 255);
                text = "SCREW";
                break;
            default:
                break;
        }
        putText(imgOut, text, Point2d(posLeft[i], posTop[i]), FONT_HERSHEY_SIMPLEX, 0.4, color);
    }
}
3)主函数
/*
@File :main.cpp
@Description: :
@Date :2021/12/25 09:23:30
@Author :xieyin
@version :1.0
*/
#include<iostream>
#include<string>
#include<sstream>
#include<memory>
using namespace std;
#include<opencv2/core.hpp>
#include<opencv2/highgui.hpp>
#include<opencv2/imgproc.hpp>
#include<opencv2/opencv.hpp>
#include<opencv2/ml.hpp>
using namespace cv;
using namespace cv::ml;
#include"mwindow.hpp"
#include"utils.hpp"
// Command-line grammar for cv::CommandLineParser:
//   @image    : test image path (positional 0)
//   @lightPat : light pattern image path (positional 1)
//   @mode     : machine learning mode (positional 2, default svm)
const char* keys =
    "{help h usage ? | | Print this message}"
    "{@image | | Image for test}"
    "{@lightPat | | light pattern for test image}"
    "{@mode | svm | machine learning mode, default svm}";
// Shared display window; declared extern in utils.hpp and created in main().
shared_ptr<MWindow> myWin;
int main(int argc, const char** argv){
// command line parser
CommandLineParser parser(argc, argv, keys);
if(parser.has("help")){
parser.printMessage();
return 0;
}
if(!parser.check()){
parser.printErrors();
return 0;
}
// define mywin
myWin = make_shared<MWindow>("Main Window", 2, 2, 700, 1000, 1);
// get test image path
String imgFile = parser.get<String>(0);
Mat img = imread(imgFile, 0);
if(img.data == NULL){
cout << "can not read image file." << endl;
return 0;
}
// get light pattern image
String ligPatFile = parser.get<String>(1);
Mat lightPat = imread(ligPatFile, 0);
if(lightPat.data == NULL){
cout << "can not read image file." << endl;
return 0;
}
// mdeianblur light pattern
medianBlur(lightPat, lightPat, 3);
// copy img to imgOut
Mat imgOut = img.clone();
cvtColor(imgOut, imgOut, COLOR_GRAY2BGR);
// preprocess image
Mat pre = preProcess(img);
// get feature and top left location from image
vector<int> posLeft, posTop;
vector<vector<float>> features = extractFeatures(pre, &posLeft, &posTop);
// get mode selection
string mode = parser.get<string>(2);
// train and predict model
if (mode == "svm"){
trainSVM();
// trainAndTest<SVM>(mode);
predict<SVM>(features, posLeft, posTop, mode, imgOut);
}
else if (mode == "bayes"){
trainAndTest<NormalBayesClassifier>(mode);
predict<NormalBayesClassifier>(features, posLeft, posTop, mode, imgOut);
}
else if(mode == "boost"){
trainAndTest<Boost>(mode);
predict<Boost>(features, posLeft, posTop, mode, imgOut);
}
else{
cout << "not support model";
return 0;
}
myWin->addImage("binary Image", pre);
myWin->addImage("result", imgOut);
myWin->render();
waitKey(0);
return 0;
}