#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>
#include <iostream>
#include <string>
using namespace std;
using namespace cv;
int main()
{
//初始化
Ptr<ml::ANN_MLP> bp = ml::ANN_MLP::create();
const int sampleNum = 6; // 训练的总样本数
const int featuresNum = 5; // 单个样本的特征向量维数
const int classesNum = 3; // 类别个数
//设置层数
Mat layerSizes = (Mat_<int>(1, 5) << featuresNum, 2, 2, 2, classesNum);
bp->setLayerSizes(layerSizes);
//设置参数
bp->setTrainMethod(ml::ANN_MLP::BACKPROP, 0.1, 0.1);
bp->setActivationFunction(ml::ANN_MLP::SIGMOID_SYM);
bp->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 10000, /*FLT_EPSILON*/1e-6));
float trainingData[sampleNum][featuresNum] = { { 1, 2, 3,4,5 }, { 4, 5, 6,7,8 }, { 111, 112, 113,114,115 }, { 124, 125, 126,120,121 }, { 21, 22, 23,24,25}, { 24, 25, 26,27,28 } };
Mat trainingDataMat(sampleNum, featuresNum, CV_32FC1, trainingData);
float labels[sampleNum][classesNum] = { { 1, 0, 0 }, { 1, 0, 0 }, { 0, 1, 0 }, { 0, 1, 0 }, { 0, 0, 1 }, { 0, 0, 1 } };
Mat labelsMat(sampleNum, classesNum, CV_32FC1, labels);
// 训练数据
bool trained = bp->train(trainingDataMat, ml::ROW_SAMPLE, labelsMat);
if (trained)
{
bp->save("bp_param");
}
// 测试
int width = 200, height = 200;
Mat image = Mat::zeros(height, width, CV_8UC3);
Vec3b green(0, 255, 0), blue(255, 0, 0),red(0,0,255),y1(255,255,0),y2(255,0,255),y3(0,255,255);
// Show the decision regions given by the SVM
for (int i = 0; i < image.rows; ++i)
{
for (int j = 0; j < image.cols; ++j)
{
Mat sampleMat = (Mat_<float>(1, 5) << i, j, 0,1,2);
Mat responseMat;
bp->predict(sampleMat, responseMat);
float* p = responseMat.ptr<float>(0);
float response = 0.0f;
float maxV = p[0];
int index = 0;
//找出最大值作为最终的类别号
for (int k = 1; k < classesNum; k++)
{
if (p[k] > maxV)
{
maxV = p[k];
index = k;
}
}
cout << index << " ";
switch (index)
{
case 0:
image.at<Vec3b>(j, i) = green;
break;
case 1:
image.at<Vec3b>(j, i) = blue;
break;
case 2:
image.at<Vec3b>(j, i) = red;
break;
default:
break;
}
}
}
imshow("BP Simple Example", image);
waitKey(0);
}
// Result: