【图像处理】-035 knn分类算法实现数字识别
在上一篇文章中,我简单的介绍了一下KNN分类算法的原理。其基本原理是计算待分类样本与训练样本之间的距离 d d d,选择与之最近的K个样本中最近的样本类别或者最多的类别作为待分类样本的类别。
1 概述
KNN分类算法没有传统意义上的训练过程,其训练过程只是简单的将训练样本和标签保存起来,用于分类时的距离计算。因此,对于小样本量、低维度样本,KNN分类算法效率较高,易于使用。但是对于大样本量或者高维度数据,由于每次需要计算待分类样本与所有训练样本的距离,导致KNN分类算法在处理这类问题时效率较低。
2 印刷体数字识别问题
这里讨论的是印刷体数字识别问题。对于印刷体数字,其字形是0-9之间10个字符,并且同一个字符在不同字号时是放大缩小的关系。因此,印刷体数字识别的分类类别是有限的,明确的。
对于手写体数字,目前已经是深度学习的入门必学实验,MNIST库也被用得烂大街了。下次有时间再来讨论手写体数字的识别问题。
由于印刷体数字的字形是有限的,那么一个常规的思路就是把各个字符的模板制作出来,然后与经过字符分割的字符图像进行模板匹配,从而识别字符。这个问题的限制条件在于同一个字符的尺寸可能是不一样的,该方法对字符分割的精度要求太高。
3 KNN印刷数字分类
对于使用KNN分类算法识别印刷数字的问题,可以将问题转化为待分类字符与训练样本集之间距离最近的样本的选择问题。
3.1 训练样本的制作
KNN分类算法需要的训练样本的数量至少需要达到期望分的类每个类至少一个样本。对于同一个类,样本量没有限制,但样本量的大小会限制分类速度。
这里,我的训练样本来自于从图像中截取的数字字符小块,并通过字符分割、二值化将数字字符保存成二值图像,从而消除字符印刷颜色对分类模型本身的影响,缩小模型体积。
3.2 分类器模型的建立
void TrainKNNModel()
{
//使用opencv自带的方式创建KNN分类模型
cv::Ptr<cv::ml::KNearest> knn = cv::ml::KNearest::create();
knn->setDefaultK(1);//设置k为1,表示找最近的一个作为分类,k与训练样本中每个类的样本数量有关。
knn->setIsClassifier(true);//设置成分类
knn->setAlgorithmType(cv::ml::KNearest::BRUTE_FORCE);
const int train_samples_number{ 100 };//训练样本总数
const int every_class_number{ 10 };//训练样本的标签数目
cv::Mat train_data(train_samples_number, 14 * 16, CV_32FC1);
cv::Mat train_labels(train_samples_number, 1, CV_32FC1);
char buf[512] = { 0 };
for (int i = 0; i < 10; i++)
{
for (int j = 0; j < 10; j++)
{
sprintf_s(buf, "./res/digs/%d/%d.jpg", i, j);//读取每个类别的训练样本
cv::Mat temp = cv::imread(buf, cv::IMREAD_GRAYSCALE);
cv::resize(temp, temp, cv::Size(14, 16));
temp.convertTo(temp, CV_32FC1);
temp = temp.reshape(0, 1);
cv::Mat tmp = train_data.row(i * 10 + j);
temp.copyTo(tmp);//保存训练样本
train_labels.at<float>(i * 10 + j, 0) = i;//设置训练样本的标签
}
}
knn->train(train_data, cv::ml::ROW_SAMPLE, train_labels);//模型训练
knn->save("./res/digs/knn.xml");//训练模型的保存,以便下次直接加载该模型。
}
3.3 进行分类
//模型加载,加载之前完成训练的模型
cv::Ptr<cv::ml::KNearest> knnModel = NULL;
bool G_bInited = false;
cv::Ptr<cv::ml::KNearest> GetKnnModel()
{
if (G_bInited == false)
{
knnModel = cv::ml::StatModel::load<cv::ml::KNearest>("./res/digs/knn.xml");
G_bInited = true;
}
return knnModel;
}
//数字识别
int RecoDigital(cv::Mat& dig)
{
cv::Mat temp;
//这里,将待识别的字符归一化到了训练样本的大小
cv::resize(dig, temp, cv::Size(14, 16));
temp.convertTo(temp, CV_32FC1);
temp = temp.reshape(0, 1);
cv::Mat matResults(0, 0, CV_32F);//保存测试结果
//KNN分类
GetKnnModel()->findNearest(temp, 1, matResults);//knn分类预测
float v = matResults.at<float>(0, 0);
return int(v);//获取分类结果
}