OpenCV3.3中给出了逻辑回归(logistic regression)的实现,即cv::ml::LogisticRegression类,类的声明在include/opencv2/ml.hpp文件中,实现在modules/ml/src/lr.cpp文件中,它既支持两分类,也支持多分类,其中:
(2)、cv::ml::LogisticRegression类继承自cv::ml::StatModel,而cv::ml::StatModel又继承自cv::Algorithm;
(2)、setLearningRate函数用来设置学习率,getLearningRate函数用来获取学习率值;
(3)、setIterations函数用来设置迭代次数,getIterations函数用来获取迭代次数值;
(4)、setRegularization函数用来设置采用哪种正则化方法,目前支持两种L1 norm和L2 norm,正则化方法主要用来防止过拟合,getRegularization函数用来获取采用哪种正则化方法;
(5)、setTrainMethod函数用来设置采用哪种训练方法,目前支持两种Batch和Mini-Batch, getTrainMethod函数用来获取采用哪种训练方法;
(6)、setMiniBatchSize函数用来设置在Mini-Batch梯度下降训练方法中每一个step采集的训练样本数,getMiniBatchSize函数用来获取每一个step采集的训练样本数;
(7)、setTermCriteria函数用来设置终止训练的条件,包括迭代次数和期望的精度,getTermCriteria用来获取终止训练的条件;
(8)、get_learnt_thetas函数用来获取训练参数;
(9)、create函数为static, new一个LogisticRegressionImpl用来创建一个LogisticRegression对象;
(10)、train函数(使用基类StatModel中的)进行训练;
(11)、predict函数用于预测;
(12)、save函数(使用基类Algorithm中的)保存已训练好的model,支持xml,yaml,json格式;
(13)、load函数用来load已训练好的model;
以下为两分类测试代码:训练数据集为从MNIST中train中随机选取的0、1各10个图像;测试数据集为从MNIST中test中随机选取的0、1各10个图像,如下图,其中第一排前10个0用于训练,后10个0用于测试;第二排前10个1用于训练,后10个1用于测试:
#include "opencv.hpp"
#include <string>
#include <vector>
#include <memory>
#include <opencv2/opencv.hpp>
#include <opencv2/ml.hpp>
#include "common.hpp"
// Logistic Regression ///
/// Displays every sample of `data` in one window and blocks until a key is pressed.
///
/// Each row of `data` is reshaped back into a 2-D image and the images are
/// stacked vertically into a single tall matrix before being shown.
///
/// @param data     Matrix holding one flattened image per row.
/// @param columns  Forwarded as the second argument of cv::Mat::reshape, i.e. the
///                 number of ROWS of each restored image. For square images
///                 (e.g. 28 for the commented-out callers) rows and columns coincide.
/// @param name     Title of the display window.
static void show_image(const cv::Mat& data, int columns, const std::string& name)
{
    cv::Mat stacked;
    const int sample_count = data.rows;
    for (int row = 0; row < sample_count; ++row) {
        // Channel count 0 keeps the channel layout unchanged; only the shape changes.
        const cv::Mat sample = data.row(row).reshape(0, columns);
        stacked.push_back(sample);
    }

    cv::imshow(name, stacked);
    cv::waitKey(0); // 0 == wait indefinitely for a key press
}
/// Computes the percentage of positions at which two label matrices agree.
///
/// @param original   Ground-truth labels.
/// @param predicted  Predicted labels; its row count is used as the denominator.
/// @return Number of element-wise equal entries divided by predicted.rows, times 100.
static float calculate_accuracy_percent(const cv::Mat& original, const cv::Mat& predicted)
{
    // Element-wise comparison yields a mask that is non-zero wherever the labels match.
    const int matches = cv::countNonZero(original == predicted);
    const int total = predicted.rows;
    return 100 * static_cast<float>(matches) / total;
}
int test_opencv_logistic_regression_train()
{
const std::string image_path{ "E:/GitCode/NN_Test/data/images/digit/handwriting_0_and_1/" };
cv::Mat data, labels, result;
for (int i = 1; i < 11; ++i) {
const std::vector<std::string> label{ "0_", "1_" };
for (const auto& value : label) {
std::string name = std::to_string(i);
name = image_path + value + name + ".jpg";
cv::Mat image = cv::imread(name, 0);
if (image.empty()) {
fprintf(stderr, "read image fail: %s\n", name.c_str());
return -1;
}
data.push_back(image.reshape(0, 1));
}
}
data.convertTo(data, CV_32F);
//show_image(data, 28, "train data");
std::unique_ptr<float[]> tmp(new float[20]);
for (int i = 0; i < 20; ++i) {
if (i % 2 == 0) tmp[i] = 0.f;
else tmp[i] = 1.f;
}
labels = cv::Mat(20, 1, CV_32FC1, tmp.get());
cv::Ptr<cv::ml::LogisticRegression> lr = cv::ml::LogisticRegression::create();
lr->setLearningRate(0.00001);
lr->setIterations(100);
lr->setRegularization(cv::ml::LogisticRegression::REG_DISABLE);
lr->setTrainMethod(cv::ml::LogisticRegression::MINI_BATCH);
lr->setMiniBatchSize(1);
CHECK(lr->train(data, cv::ml::ROW_SAMPLE, labels));
const std::string save_file{ "E:/GitCode/NN_Test/data/logistic_regression_model.xml" }; // .xml, .yaml, .jsons
lr->save(save_file);
return 0;
}
int test_opencv_logistic_regression_predict()
{
const std::string image_path{ "E:/GitCode/NN_Test/data/images/digit/handwriting_0_and_1/" };
cv::Mat data, labels, result;
for (int i = 11; i < 21; ++i) {
const std::vector<std::string> label{ "0_", "1_" };
for (const auto& value : label) {
std::string name = std::to_string(i);
name = image_path + value + name + ".jpg";
cv::Mat image = cv::imread(name, 0);
if (image.empty()) {
fprintf(stderr, "read image fail: %s\n", name.c_str());
return -1;
}
data.push_back(image.reshape(0, 1));
}
}
data.convertTo(data, CV_32F);
//show_image(data, 28, "test data");
std::unique_ptr<int[]> tmp(new int[20]);
for (int i = 0; i < 20; ++i) {
if (i % 2 == 0) tmp[i] = 0;
else tmp[i] = 1;
}
labels = cv::Mat(20, 1, CV_32SC1, tmp.get());
const std::string model_file{ "E:/GitCode/NN_Test/data/logistic_regression_model.xml" };
cv::Ptr<cv::ml::LogisticRegression> lr = cv::ml::LogisticRegression::load(model_file);
lr->predict(data, result);
fprintf(stdout, "predict result: \n");
std::cout << "actual: " << labels.t() << std::endl;
std::cout << "target: " << result.t() << std::endl;
fprintf(stdout, "accuracy: %.2f%%\n", calculate_accuracy_percent(labels, result));
return 0;
}
测试代码中,test_opencv_logistic_regression_train函数用于训练,训练结果会产生一个叫logistic_regression_model.xml的model文件;test_opencv_logistic_regression_predict函数用于预测,预测结果如下,由结果可知,预测全部正确: