训练分类器
参考:
OpenCV3 Java 机器学习使用方法汇总
OpenCV 之 神经网络 (一)
package train;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.TermCriteria;
import org.opencv.ml.ANN_MLP;
import org.opencv.ml.Ml;
import org.opencv.ml.TrainData;
public class Study_02 {
static {
System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
}
public static void run() {
// 训练数据,两个维度,表示身高和体重
float[] trainingData = { 186, 80, 185, 81, 160, 50, 161, 48 };
// 训练标签数据,前两个表示男生0,后两个表示女生1,由于使用了多种机器学习算法,他们的输入有些不一样,所以labelsMat有三种
float[] labels = { 0f, 0f, 0f, 0f, 1f, 1f, 1f, 1f };
int[] labels2 = { 0, 0, 1, 1 };
float[] labels3 = { 0, 0, 1, 1 };
// 测试数据,先男后女
float[] test = { 184, 79, 159, 50 };
Mat trainingDataMat = new Mat(4, 2, CvType.CV_32FC1);
trainingDataMat.put(0, 0, trainingData);
Mat labelsMat = new Mat(4, 2, CvType.CV_32FC1);
labelsMat.put(0, 0, labels);
Mat labelsMat2 = new Mat(4, 1, CvType.CV_32SC1);
labelsMat2.put(0, 0, labels2);
Mat labelsMat3 = new Mat(4, 1, CvType.CV_32FC1);
labelsMat3.put(0, 0, labels3);
Mat sampleMat = new Mat(2, 2, CvType.CV_32FC1);
sampleMat.put(0, 0, test);
MyAnn(trainingDataMat, labelsMat, sampleMat);
}
// 人工神经网络
public static Mat MyAnn(Mat trainingData, Mat labels, Mat testData) {
/*
* TrainData create(Mat samples, int layout, Mat responses)
* - samples:矩阵样本,需要 CV_32F 类型
* - labels: ROW_SAMPLE 行样本,COL_SAMPLE 列样本
* - responses:对应样本的分类结果。
*/
TrainData td = TrainData.create(trainingData, Ml.ROW_SAMPLE, labels);
Mat layerSizes = new Mat(1, 4, CvType.CV_32FC1);
// 含有两个隐含层的网络结构,输入、输出层各两个节点,每个隐含层含两个节点
layerSizes.put(0, 0, new float[] { 2, 2, 2, 2 });
ANN_MLP ann = ANN_MLP.create(); //创建空模型
ann.setLayerSizes(layerSizes); //设置神经网络的层数和神经元数量
/* setTrainMethod(): 设置训练方法。BACKPROP,RPROP,ANNEAL,默认是RPROP
* setActivationFunction(type, param1, param2) 设置激活函数。
* - type: 目前只支持ANN_MLP.SIGMOID_SYM。
* - param1: 对应 setBackpropWeightScale()中的参数。
* - param2: 对应 setBackpropMomentumScale()中的参数。
*/
ann.setTrainMethod(ANN_MLP.BACKPROP);
ann.setBackpropWeightScale(0.1); //默认值0.1
ann.setBackpropMomentumScale(0.1);//默认值 0.1
ann.setActivationFunction(ANN_MLP.SIGMOID_SYM, 1, 1);
//设置迭代终止准则,默认为TermCriteria(TermCriteria.MAX_ITER + TermCriteria.EPS, 1000, 0.01)
ann.setTermCriteria(new TermCriteria(TermCriteria.MAX_ITER + TermCriteria.EPS, 300, 0.0));
boolean success = ann.train(td.getSamples(), Ml.ROW_SAMPLE, td.getResponses());
System.out.println("Ann training result: " + success);
// ann.save("D:/bp.xml");//存储模型
// ann.load("D:/bp.xml");//读取模型
// 测试数据
Mat responseMat = new Mat();
ann.predict(testData, responseMat, 0);
System.out.println("Ann responseMat:\n" + responseMat.dump());
for (int i = 0; i < responseMat.size().height; i++) {
if (responseMat.get(i, 0)[0] + responseMat.get(i, i)[0] >= 1)
System.out.println("Girl\n");
if (responseMat.get(i, 0)[0] + responseMat.get(i, i)[0] < 1)
System.out.println("Boy\n");
}
return responseMat;
}
public static void main(String[] args) {
run();
}
}