本博客属于原创,若转载请标明转载出处:https://blog.csdn.net/qq_44091004/article/details/90573495
本文省略了ANN的原理部分,默认读者已具备一定的基础。本案例是使用ANN预测手写数字。
开发环境:新建一个工程(new project),并确保该工程已经配置好OpenCV开发环境。
第一步:准备数据集。数据来源于小编编写的一个可以直接运行的App:在面板上手绘数字,再通过缩小功能将图片统一缩小为12*12的尺寸,然后保存到自己指定的目录中,例如小编将手写图片数据保存在E://opencv3.1//samples//ocr中。源数据集如下图所示。
第二步:建立手写数字的资料库。
该数据库包含50个训练样本(数字0~9,每个数字5个手写样本),以及一个测试样本sample7。下面的代码中有几点需要注意:首先,将每张图片的二维像素矩阵转换为一个行样本(144个特征)存储在Mat中;其次,ANN中每个样本的标签用一个1行10列的矩阵来存储(对应数字的位置为1,其余位置为0)。
package com;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.imgcodecs.Imgcodecs;
/**
 * Hand-written digit data set for the ANN example.
 *
 * <p>Loads 50 training images ({@code 00.jpg} .. {@code 49.jpg}, five consecutive images per
 * digit 0-9) and one test image ({@code number71.jpg}) from a fixed directory. Each image is read
 * as grayscale, flattened into one 144-column {@code CV_32FC1} row, and each training row gets a
 * one-hot 1x10 label row.
 */
public class OcrDatabase {
    static {
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    }

    /** Directory holding the sample images. */
    private static final String DATA_DIR = "E://opencv3.1//samples//ocr//";
    /** Number of training images. */
    private static final int TRAINING_COUNT = 50;
    /** Number of digit classes (0-9). */
    private static final int CLASS_COUNT = 10;
    /** Consecutive training images per digit (file order groups images by digit). */
    private static final int SAMPLES_PER_CLASS = 5;
    /** Pixels per 12x12 image, i.e. features per sample. */
    private static final int FEATURE_COUNT = 144;

    /** One-hot training labels, one row per training image. */
    Mat trainingLabelsFloatMat = new Mat(TRAINING_COUNT, CLASS_COUNT, CvType.CV_32FC1);
    // NOTE: the misspelled field name is kept so any code referencing it keeps compiling;
    // prefer getTrainingDataMat().
    /** Training features, one flattened image per row. */
    Mat trianingDataMat = new Mat(TRAINING_COUNT, FEATURE_COUNT, CvType.CV_32FC1);
    /** Test-set matrix; allocated here but not filled by this class. */
    Mat testingDataMat = new Mat(10, FEATURE_COUNT, CvType.CV_32FC1);
    /** Test labels; not used by this class. */
    float[] testingLabels = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    /** Row i is the one-hot label for digit i / SAMPLES_PER_CLASS. */
    float[][] trainingLabelsFloat = oneHotLabels(TRAINING_COUNT, CLASS_COUNT, SAMPLES_PER_CLASS);
    /** The single test sample loaded from number71.jpg. */
    Mat sample7 = new Mat(1, FEATURE_COUNT, CvType.CV_32FC1);

    /**
     * Loads all training images, their labels, and the test sample from {@code DATA_DIR}.
     *
     * @throws IllegalStateException if an image cannot be read or does not have 144 pixels
     */
    public OcrDatabase() {
        for (int i = 0; i < TRAINING_COUNT; i++) {
            // Files are named with two digits: 00.jpg .. 49.jpg.
            String name = (i < 10 ? "0" : "") + i;
            copyImageToRow(DATA_DIR + name + ".jpg", trianingDataMat, i);
            trainingLabelsFloatMat.put(i, 0, trainingLabelsFloat[i]);
        }
        copyImageToRow(DATA_DIR + "number71.jpg", sample7, 0);
    }

    /**
     * Reads {@code path} as a grayscale image and stores its pixels, converted to 32-bit float,
     * as row {@code row} of {@code dest}.
     *
     * @throws IllegalStateException if the image cannot be read or its pixel count differs from
     *     the column count of {@code dest}
     */
    private static void copyImageToRow(String path, Mat dest, int row) {
        Mat image = Imgcodecs.imread(path, Imgcodecs.IMREAD_GRAYSCALE);
        if (image.empty()) {
            // imread signals an unreadable/missing file with an empty Mat, not an exception.
            throw new IllegalStateException("Cannot read image: " + path);
        }
        if (image.total() != dest.cols()) {
            throw new IllegalStateException(
                    "Expected " + dest.cols() + " pixels but got " + image.total() + " in " + path);
        }
        Mat floatRow = new Mat();
        // Flatten to a single row and convert the 8-bit pixels to float (no scaling).
        image.reshape(1, 1).convertTo(floatRow, CvType.CV_32F);
        // dest.row(row) is a view into dest with matching size/type, so copyTo writes in place.
        floatRow.copyTo(dest.row(row));
    }

    /** Builds one-hot rows where row i has a 1 in column i / samplesPerClass. */
    private static float[][] oneHotLabels(int samples, int classes, int samplesPerClass) {
        float[][] labels = new float[samples][classes];
        for (int i = 0; i < samples; i++) {
            labels[i][i / samplesPerClass] = 1f;
        }
        return labels;
    }

    /** Returns the training feature matrix (one flattened image per row). */
    public Mat getTrainingDataMat() {
        return trianingDataMat;
    }

    /** Replaces the training feature matrix. */
    public void setTrainingDataMat(Mat trainingDataMat) {
        this.trianingDataMat = trainingDataMat;
    }

    /** Returns the one-hot training label matrix. */
    public Mat getTrainingLabelsFloatMat() {
        return trainingLabelsFloatMat;
    }

    /** Replaces the training label matrix. */
    public void setTrainingLabelsFloatMat(Mat trainingLabelsFloatMat) {
        this.trainingLabelsFloatMat = trainingLabelsFloatMat;
    }
}
第三步:测试类
import org.opencv.core.Core;
import org.opencv.ml.ANN_MLP;
import org.opencv.ml.Ml;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.TermCriteria;
public class KNN {
static{ System.loadLibrary(Core.NATIVE_LIBRARY_NAME); }
public static void main(String[] args) {
OcrDatabase ocr=new OcrDatabase();//创建数据集对象,相当于做试验前准备好的实验数据,封装了起来,以便调用
ANN_MLP ann=ANN_MLP.create();//创建ANN
//设置网络的模型,首先创建网络模型对象,然后创建数组,将数组中的数据存储到Mat对象中,分别包含输入层,隐含层和输出层,其中输入
//层为144个特征,隐含层包括两层,分别有20和10个神经元,输出层为10个数据
Mat layerSize=new Mat(4,1,CvType.CV_32SC1);
int[] layerSizeAry={144, 20,10, 10};
layerSize.put(0,0,layerSizeAry[0]);
layerSize.put(1,0,layerSizeAry[1]);
layerSize.put(2,0,layerSizeAry[2]);
layerSize.put(3,0,layerSizeAry[3]);
ann.setLayerSizes(layerSize);
ann.setTrainMethod(ann.BACKPROP);//设置训练方法:误差反向传播
TermCriteria criteria=new TermCriteria(TermCriteria.MAX_ITER|TermCriteria.EPS, 300, 0.001);//创建对象,设置标准属性包括最大迭代次数和误差率
ann.setTermCriteria(criteria);
ann.setActivationFunction(ann.SIGMOID_SYM);//设置激励函数为SIGMOLD类型
boolean r=ann.train(ocr.getTrainingDataMat(), Ml.ROW_SAMPLE, ocr.getTrainingLabelsFloatMat());//训练函数,样本类型为行样本,然后获取训练数据集和标签进行训练
System.out.println("是否有训练成功="+r);
//测试sample7
float result7= ann.predict(ocr.sample7);
System.out.println("预测7结果="+result7);
}
}
运行结果如下图所示:
最后,作者参考的书籍:
《opencv3 使用java开发手册》