Java OpenCV 人工智能01.0 机器学习 支持向量机 SVM
下表列出了 OpenCV 中提供的机器学习(Machine Learning,ML)算法:
序号 | 中文名称 | 英文名称 |
---|---|---|
01 | 决策树 | Decision Tree |
02 | EM算法(期望最大化) | Expectation-Maximization |
03 | 贝叶斯分类 | Normal Bayes Classifier |
04 | K-近邻算法 | K-Nearest Neighbour Classifier |
05 | 支持向量机 | Support Vector Machine |
06 | Boost树算法 | Boosted Tree Classifier |
07 | 随机森林算法 | Random Trees Classifier |
08 | 人工神经网络 | Artificial Neural Networks |
09 | 梯度Boost树算法 | Gradient Boosted Trees |
10 | 极端随机树算法 | Extremely Randomized Trees Classifier |
11 | SGD支持向量机 | Stochastic Gradient Descent Support Vector Machine |
package com.xu.opencv;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.TermCriteria;
import org.opencv.ml.Ml;
import org.opencv.ml.SVM;
import org.opencv.ml.TrainData;
/**
 * Minimal OpenCV (4.1.0) machine-learning demo: trains a linear C-SVC support
 * vector machine on (height, weight) samples labelled by gender, saves the
 * model, and prints the predicted gender for a few test samples.
 *
 * <p>Copyright: 2019 hyacinth
 *
 * @author hyacinth
 * @version V-1.0
 * @since 2019-11-22
 */
public class ML {

    /** File the trained model is written to by {@link #SVM(Mat, Mat, Mat)}. */
    private static final String DEFAULT_MODEL_PATH = "C:\\Users\\Administrator\\Desktop\\number.xml";

    /** Class label used for male samples. */
    private static final int LABEL_MALE = 0;

    /** Class label used for female samples. */
    private static final int LABEL_FEMALE = 1;

    /** Number of features per sample (height, weight). */
    private static final int FEATURE_COUNT = 2;

    static {
        // The OpenCV native library must be loaded before any Mat is created.
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    }

    /**
     * Builds the training/test matrices from the hard-coded sample data and
     * runs the SVM demo on them.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // Training samples, one per row: height, weight.
        float[] trainData = { 186, 80, 185, 81, 165, 60, 168, 61, 160, 50, 161, 48 };
        // Test samples, same column order as the training data: height, weight.
        float[] testData = { 184, 79, 160, 62, 159, 50 };
        // One label per training row: 0 = male, 1 = female.
        int[] labels = { 0, 0, 0, 0, 1, 1 };

        Mat trainMat = new Mat(trainData.length / FEATURE_COUNT, FEATURE_COUNT, CvType.CV_32FC1);
        Mat trainLabels = new Mat(labels.length, 1, CvType.CV_32SC1);
        Mat testMat = new Mat(testData.length / FEATURE_COUNT, FEATURE_COUNT, CvType.CV_32FC1);
        try {
            trainMat.put(0, 0, trainData);
            trainLabels.put(0, 0, labels);
            testMat.put(0, 0, testData);
            SVM(trainMat, trainLabels, testMat);
        } finally {
            // Mats wrap native memory; free it deterministically instead of
            // waiting for garbage collection.
            trainMat.release();
            trainLabels.release();
            testMat.release();
        }
    }

    /**
     * Trains a linear C-SVC on the given data, saves the model to the default
     * location, and prints the predictions for {@code test}.
     *
     * @param tarin training samples, one sample per row
     * @param lable training labels, one row per training sample
     * @param test  samples to classify, one sample per row
     * @throws IllegalStateException if training does not succeed
     */
    public static void SVM(Mat tarin, Mat lable, Mat test) {
        SVM(tarin, lable, test, DEFAULT_MODEL_PATH);
    }

    /**
     * Trains a linear C-SVC on the given data, optionally saves the model,
     * and prints the predictions for {@code test}: first the raw response
     * matrix, then one line per test row with the predicted gender.
     *
     * @param tarin     training samples, one sample per row
     * @param lable     training labels, one row per training sample
     * @param test      samples to classify, one sample per row
     * @param modelPath file to save the trained model to; {@code null} or
     *                  empty skips saving
     * @throws IllegalStateException if training does not succeed
     */
    public static void SVM(Mat tarin, Mat lable, Mat test, String modelPath) {
        SVM svm = SVM.create();
        svm.setType(SVM.C_SVC);
        svm.setKernel(SVM.LINEAR);
        svm.setC(1);
        svm.setP(0);
        svm.setNu(0);
        svm.setCoef0(0);
        svm.setGamma(1);
        svm.setDegree(0);
        svm.setTermCriteria(new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER, 1000, 0));

        // train(samples, layout, responses) takes the matrices directly, so no
        // intermediate TrainData wrapper is needed.
        boolean trained = svm.train(tarin, Ml.ROW_SAMPLE, lable);
        if (!trained) {
            throw new IllegalStateException("SVM training failed");
        }

        if (modelPath != null && !modelPath.isEmpty()) {
            svm.save(modelPath);
        }

        Mat response = new Mat();
        try {
            svm.predict(test, response, 0);
            System.out.println(response.dump());
            for (int row = 0; row < response.rows(); row++) {
                // Predictions come back as floating-point class labels; round
                // once so the comparison is against integer labels.
                int predicted = (int) Math.round(response.get(row, 0)[0]);
                if (predicted == LABEL_MALE) {
                    System.out.println("男");
                } else if (predicted == LABEL_FEMALE) {
                    System.out.println("女");
                }
            }
        } finally {
            response.release();
        }
    }
}