logistic回归 java_OpenCV3.3中逻辑回归(Logistic Regression)使用举例

OpenCV3.3中给出了逻辑回归(logistic regression)的实现,即cv::ml::LogisticRegression类,类的声明在include/opencv2/ml.hpp文件中,实现在modules/ml/src/lr.cpp文件中,它既支持两分类,也支持多分类,其中:

(1)、cv::ml::LogisticRegression类继承自cv::ml::StateModel,而cv::ml::StateModel又继承自cv::Algorithm;

(2)、setLearningRate函数用来设置学习率,getLearningRate函数用来获取学习率值;

(3)、setIterations函数用来设置迭代次数,getIterations函数用来获取迭代次数值;

(4)、setRegularization函数用来设置采用哪种正则化方法,目前支持两种L1 norm和L2 norm,正则化方法主要用来防止过拟合,getRegularization函数用来获取采用哪种正则化方法;

(5)、setTrainMethod函数用来设置采用哪种训练方法,目前支持两种Batch和Mini-Batch, getTrainMethod函数用来获取采用哪种训练方法;

(6)、setMiniBatchSize函数用来设置在Mini-Batch梯度下降训练方法中每一个step采集的训练样本数,getMiniBatchSize函数用来获取每一个step采集的训练样本数;

(7)、setTermCriteria函数用来设置终止训练的条件,包括迭代次数和期望的精度,getTermCriteria用来获取终止训练的条件;

(8)、get_learnt_thetas函数用来获取训练参数;

(9)、create函数为static, new一个LogisticRegressionImpl用来创建一个LogisticRegression对象;

(10)、train函数(使用基类StatModel中的)进行训练;

(11)、predict函数用于预测;

(12)、save函数(使用基类Algorithm中的)保存已训练好的model,支持xml,yaml,json格式;

(13)、load函数用来load已训练好的model;

以下为两分类测试代码:训练数据集为从MNIST中train中随机选取的0、1各10个图像;测试数据集为从MNIST中test中随机选取的0、1各10个图像,如下图,其中第一排前10个0用于训练,后10个0用于测试;第二排前10个1用于训练,后10个1用于测试:

c0a57cb025d5662f984161aa71e2230a.png

#include "opencv.hpp" #include #include #include #include #include #include "common.hpp" // Logistic Regression /// static void show_image(const cv::Mat& data, int columns, const std::string& name) { cv::Mat big_image; for (int i = 0; i < data.rows; ++i) { big_image.push_back(data.row(i).reshape(0, columns)); } cv::imshow(name, big_image); cv::waitKey(0); } static float calculate_accuracy_percent(const cv::Mat& original, const cv::Mat& predicted) { return 100 * (float)cv::countNonZero(original == predicted) / predicted.rows; } int test_opencv_logistic_regression_train() { const std::string image_path{ "E:/GitCode/NN_Test/data/images/digit/handwriting_0_and_1/" }; cv::Mat data, labels, result; for (int i = 1; i < 11; ++i) { const std::vector<:string> label{ "0_", "1_" }; for (const auto& value : label) { std::string name = std::to_string(i); name = image_path + value + name + ".jpg"; cv::Mat image = cv::imread(name, 0); if (image.empty()) { fprintf(stderr, "read image fail: %sn", name.c_str()); return -1; } data.push_back(image.reshape(0, 1)); } } data.convertTo(data, CV_32F); //show_image(data, 28, "train data"); std::unique_ptr tmp(new float[20]); for (int i = 0; i < 20; ++i) { if (i % 2 == 0) tmp[i] = 0.f; else tmp[i] = 1.f; } labels = cv::Mat(20, 1, CV_32FC1, tmp.get()); cv::Ptr<:ml::logisticregression> lr = cv::ml::LogisticRegression::create(); lr->setLearningRate(0.00001); lr->setIterations(100); lr->setRegularization(cv::ml::LogisticRegression::REG_DISABLE); lr->setTrainMethod(cv::ml::LogisticRegression::MINI_BATCH); lr->setMiniBatchSize(1); CHECK(lr->train(data, cv::ml::ROW_SAMPLE, labels)); const std::string save_file{ "E:/GitCode/NN_Test/data/logistic_regression_model.xml" }; // .xml, .yaml, .jsons lr->save(save_file); return 0; } int test_opencv_logistic_regression_predict() { const std::string image_path{ "E:/GitCode/NN_Test/data/images/digit/handwriting_0_and_1/" }; cv::Mat data, labels, result; for (int i = 11; i < 21; ++i) { const std::vector<:string> label{ "0_", "1_" }; for (const auto& value : label) { std::string name = std::to_string(i); name = image_path + value + name + ".jpg"; cv::Mat image = cv::imread(name, 0); if (image.empty()) { fprintf(stderr, "read image fail: %sn", name.c_str()); return -1; } data.push_back(image.reshape(0, 1)); } } data.convertTo(data, CV_32F); //show_image(data, 28, "test data"); std::unique_ptr tmp(new int[20]); for (int i = 0; i < 20; ++i) { if (i % 2 == 0) tmp[i] = 0; else tmp[i] = 1; } labels = cv::Mat(20, 1, CV_32SC1, tmp.get()); const std::string model_file{ "E:/GitCode/NN_Test/data/logistic_regression_model.xml" }; cv::Ptr<:ml::logisticregression> lr = cv::ml::LogisticRegression::load(model_file); lr->predict(data, result); fprintf(stdout, "predict result: n"); std::cout << "actual: " << labels.t() << std::endl; std::cout << "target: " << result.t() << std::endl; fprintf(stdout, "accuracy: %.2f%%n", calculate_accuracy_percent(labels, result)); return 0; }

测试代码中,test_opencv_logistic_regression_train函数用于训练,训练结果会产生一个叫logistic_regression_model.xml的model文件;test_opencv_logistic_regression_predict函数用于预测,预测结果如下,由结果可知,预测全部正确:

1aecc811bf2519398e15ff2303d9c6ce.png

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值