【SVM理论到实践4】基于OpenCv中的SVM的手写体数字识别

//由于本人每天时间非常紧张,所以细节写的不详细,博客仅供各位参考,里面的代码都是运行过的,直接可以运行

本章的学习目标:

     1)手写体数字识别数据库MNIST

     2)基于SVM训练的具体步骤  

1)手写体数字识别数据库MNIST

MNIST(Mixed National Institute of Standards and Technology)是一个大型的手写体数字识别数据库,广泛应用于机器学习领域的训练和测试,由纽约大学Yann LeCun教授整理。MNIST包括60000个训练集和10000个测试集,每张图都已经进行了尺寸归一化,数字居中处理,固定为28*28像素。具体的下载地址如下所示:

mnist数据下载:http://yann.lecun.com/exdb/mnist/ 

2)基于SVM训练的具体步骤

训练的过程如下所示:

 1)读取Mnist训练集数据

 2)训练

 3)读取Mnist训练集数据,对比预测结果,得到错误率

3)具体的实现如下所示:

  1)mnist给出的数据文件是二进制文件,四个文件解压之后的情况如下所示:

      1)”train-images.idx3-ubyte”二进制文件,存储了头文件信息以及60000张28*28分辨率的图像信息(用于训练)

      2)”train-labels.idx1-ubyte”二进制文件,存储了文件头信息以及60000张label信息

      3)”t10k-images.idx-ubyte”二进制文件,存储了文件头信息以及10000张28*28分辨率的图想信息(用于测试)

      4)“t10k-labels.idx-ubyte”二进制文件,存储了头文件信息以及10000张图像label信息

  2)因为OpenCv中没有直接导入MNIST数据的文件,所以需要自己写函数来读取MNIST的数据文件

       1)首先,要知道MNIST数据的数据格式:IMAGE_FILE---包含四个int型的头部数据(magic_number,number_of_images,

             number_of_rows,number_of_columns)

       2)余下的每一个byte表示一个pixel的数据,范围是0~255(可以在读入的时候scale到0~1的区间)

       3)LABEL_FILE包含两个型的头部数据(magic_number,number_of_items),余下的每一个byte表示一个label数据,范围是0~9

             我们可以参考下图所示,更加具体的信息可以去MNIST官网了解:

            

   3)此块要注意的第一个坑是:MNIST是大端存储,然而大部分的intel处理器都是小端存储,所以对于int、long、float这些多字节的数据类型,就要一个一个byte地翻转过来,才能正确的显示。

   4)此块注意的第二个坑是:如果用第一条打开文件,不会报错,但是数据会出现错误,头部数据仍然正确,但是后面的pixel数据大部分都是0

    不能用ifstream file(fileName);

   而要改成ifstream file(fileName, ios::binary);

   5)此块注意的第三个坑是:training时,IMAGE和LABEL的数据分别都放进一个MAT中存储,但是只能是CV32_F或者是CV32_S的格式,不然会报错,OpenCv文档给出的例子是这样的:(但是predict的时候又会要求label的格式是unsigned int),所以...可以设置data的Mat格式为CV_32FC1,label的Mat格式为CV_32SC1,当然,最好都设置为CV_32FC1

6)顺便地,图像训练数据的转换格式,也就是说,我们都进来的图像数据都是二维的矩阵,但是我们在训练的时候,需要把二维的图像矩阵拉为一维的向量

7)最后,为了验证数据的正确性,一个有效的办法就是输出第一个和最后一个数据和原始图像的数据进行对比

8)还有需要说明的一点是,此处,我们是直接对原始图像进行训练,并没有对任何对图像的任何特征进行提取;我们也可以在图像进行训练之前,先利用Harris,SIFT,SURF,FAST,BRIRF,ORB,HOG这些提取图像的特征,然后再把提取的特征向量组成训练集进行训练。

/***************************************************************************************************** 
文件描述: 
        头文件mnist.h
开发环境: 
        VS2012 + OpenGl(GLUT3.7) + OpenCv2.4.9 + Halcon10.0 
时间地点: 
        陕西师范大学----2017.3.3
作    者: 
        九月 
*****************************************************************************************************/ 
#ifndef MNIST_H
#define MNIST_H
#include<iostream>
#include<string>
#include<fstream>
#include<ctime>
#include<opencv2/core/core.hpp>
#include<opencv2/highgui/highgui.hpp>
#include<opencv2/imgproc/imgproc.hpp>
#include<opencv2/ml/ml.hpp>

using namespace std;
using namespace cv;

int     ReverseInt(int i);                      //[1]大小端存储转换
cv::Mat ReadMnistImage(const string fileName);  //[2]读取Image的数据信息
cv::Mat ReadMnistLabel(const string fileName);  //[3]读取Label数据信息
#endif

/***************************************************************************************************** 
文件描述: 
        头文件mnist.h的实现文件mnist.cpp
开发环境: 
        VS2012 + OpenGl(GLUT3.7) + OpenCv2.4.9 + Halcon10.0 
时间地点: 
        陕西师范大学----2017.3.3
作    者: 
        九月 
*****************************************************************************************************/ 
#include"mnist.h"
#include<ctime>
#include<iostream>

using namespace std;

int  testNum = 10000;
/***************************************************************************************************** 
函数功能:
         大小端存储模式的数据转换
*****************************************************************************************************/
int  ReverseInt(int i)
{
	unsigned char c1;
	unsigned char c2;
	unsigned char c3;
	unsigned char c4;

	c1 = i&255;
	c2 = (i>>8)&255;
	c3 = (i>>16)&255;
	c4 = (i>>24)&255;

	return ((int)c1<<24)+((int)c2<<16)+((int)c3<<8)+c4;
}
/***************************************************************************************************** 
函数功能:
         读取Minst数据库的图像二进制文件
注意问题:
         此块我们需要注意的问题是:当我们从MINIST数据库中读进来图像文件后,我们将读进来的文件存储在
		 dataMat矩阵容器中,这就是我们送给SVM的训练样本;要注意的是,在这个矩阵容器中,矩阵dataMat
		 中的一行,就代表一个样本,就是实际中的一幅图片;
		 我们有多少张图片,我们就有多少训练样本,这个矩阵就有多少行。
*****************************************************************************************************/
cv::Mat ReadMnistImage(const string fileName)
{
	double      constTime;
    std::clock_t startTime;
    std::clock_t endTime;

	int magicNumber    = 0;
	int numberOfImages = 0;
	int nRows          = 0;
	int nCols          = 0;
	
	cv::Mat dataMat;
	std::ifstream file(fileName,ios::binary);

	if(file.is_open())
	{
		std::cout<<"[NOTICE]The set of Images is opened sucessfully!"<<std::endl;
		file.read((char*)&magicNumber,sizeof(magicNumber));
		file.read((char*)&numberOfImages,sizeof(numberOfImages));
		file.read((char*)&nRows,sizeof(nRows));
		file.read((char*)&nCols,sizeof(nCols));

		magicNumber    = ReverseInt(magicNumber);
		numberOfImages = ReverseInt(numberOfImages);
		nRows          = ReverseInt(nRows);
		nCols          = ReverseInt(nCols);

		std::cout<<"[1]magicNumber    = "<<magicNumber<<std::endl;
		std::cout<<"[2]numberOfImages = "<<numberOfImages<<std::endl;
		std::cout<<"[3]nRows          = "<<nRows<<std::endl;
		std::cout<<"[4]nCols          = "<<nCols<<std::endl;

		//输出第一张和最后一张图片,检查数据无误
		cv::Mat s = cv::Mat::zeros(nCols,nRows*nCols,CV_32FC1);
		cv::Mat e = cv::Mat::zeros(nCols,nRows*nCols,CV_32FC1);

		std::cout<<"[NOTICE]Read the data of Imagess---->>Start!"<<std::endl;
		startTime = std::clock();
		
		dataMat =  cv::Mat::zeros(numberOfImages,nRows*nCols,CV_32FC1);
		//for(int i=0;i<numberOfImages;i++)
		for(int i=0;i<1000;i++)
		{
			for(int j=0;j<nRows*nCols;j++)
			{
				unsigned char temp = 0;
				file.read((char*)&temp,sizeof(temp));
				//std::cout<<"temp = "<<temp<<std::endl;
				float pixelValue = (float)((temp+0.0)/255.0);
				//std::cout<<"pixelValue = "<<pixelValue<<std::endl;
				dataMat.at<float>(i,j) = pixelValue;
				//打印第一张和最后一张图像数据
				if(i==0)
				{
					s.at<float>(j/nCols,j%nCols) = pixelValue;
				}
				else if(i==numberOfImages-1)
				{
					e.at<float>(j/nCols,j%nCols) = pixelValue;
				}
			}//for j
		}//for i
		endTime  = std::clock();
		constTime= (endTime-startTime);
		std::cout<<"[NOTICE]Read the data of Images---->>Finish!"<<std::endl;
		std::cout<<"[NOTICE]Running Time = "<<constTime<<"ms"<<std::endl;
		cv::imshow("firstImage",s);
		cv::imshow("last image",e);
		cv::waitKey(0);
	}
	file.close();
	return dataMat;
}
/***************************************************************************************************** 
函数功能:
         读Mnist数据库中的标签
*****************************************************************************************************/
 cv::Mat ReadMnistLabel(const string fileName) 
 {
	  double constTime;
      std::clock_t startTime;
      std::clock_t endTime;

      int magicNumber   = 0;
      int numberOfItems = 0;
 
      cv::Mat labelMat;
 
      std::ifstream file(fileName,ios::binary);
      if (file.is_open())
      {
          std::cout<<"[NOTICE]The set of Label is opened sucessfully!"<<std::endl;
          file.read((char*)&magicNumber,sizeof(magicNumber));
          file.read((char*)&numberOfItems,sizeof(numberOfItems));
          magicNumber   = ReverseInt(magicNumber);
          numberOfItems = ReverseInt(numberOfItems);
 
		  std::cout<<"[1]magicNumber    = "<<magicNumber<<std::endl;
		  std::cout<<"[2]numberOfItems  = "<<numberOfItems<<std::endl;


          //记录第一个label和最后一个label
          unsigned int s = 0;
		  unsigned int e = 0;

          std::cout<<"[NOTICE]Read the data of Labels---->>Start!"<<std::endl;
		  startTime = std::clock();
          labelMat = Mat::zeros(numberOfItems,1,CV_32SC1);
          //for (int i = 0; i < numberOfItems; i++) 
		  for (int i = 0; i < 1000; i++) 
		  {
             unsigned char temp = 0;

             file.read((char*)&temp,sizeof(temp));

             labelMat.at<unsigned int>(i, 0) = (unsigned int)temp;
 
             //打印第一个和最后一个label
             if(i == 0)
			 {
					 s = (unsigned int)temp;
			  }
             else if(i == numberOfItems-1)
			 {
					 e = (unsigned int)temp;
			  }
          }
		  endTime = clock();
		  constTime= (endTime-startTime);
		  std::cout<<"[NOTICE]Read the data of Images---->>Finish!"<<std::endl;
		  std::cout<<"[NOTICE]Running Time = "<<constTime<<"ms"<<std::endl;
          std::cout<<"[1]first label = " << s << endl;
          std::cout<<"[2]last  label = " << e << endl;
     }
     file.close();
     return labelMat;
 }
/***************************************************************************************************** 
程序功能: 
        基于OpenCv中SVM的Minist手写体字符识别
开发环境: 
        VS2012 + OpenGl(GLUT3.7) + OpenCv2.4.9 + Halcon10.0 
时间地点: 
        陕西师范大学----2017.3.3
作    者: 
        九月 
*****************************************************************************************************/ 
#include"mnist.h"
#include<opencv2/core/core.hpp>
#include<opencv2/imgproc/imgproc.hpp>
#include<opencv2/highgui/highgui.hpp>
#include<opencv2/ml/ml.hpp>
#include<ctime>
#include<string>
#include<iostream>
  
using namespace std;
using namespace cv;


std::string trainImage = "mnist_dataset/train-images.idx3-ubyte";
std::string trainLabel = "mnist_dataset/train-labels.idx1-ubyte";
std::string testImage  = "mnist_dataset/t10k-images.idx3-ubyte";
std::string testLabel  = "mnist_dataset/t10k-labels.idx1-ubyte";

int main()
{
	double consumeTime    = 0;
	std::clock_t startTime = 0;
	std::clock_t endTime   = 0;

	cv::Mat trainData;
	cv::Mat trainDataLabels;
	//【1】读入训练样本
	trainData       = ReadMnistImage(trainImage);                        
	trainDataLabels = ReadMnistLabel(trainLabel);
	std::cout<<"[1]trainData.rows*trainData.cols             = "<<trainData.rows<<"*"<<trainData.cols<<std::endl;
	std::cout<<"[2]trainDataLabels.rows*trainDataLabels.cols = "<<trainDataLabels.rows<<"*"<<trainDataLabels.cols<<std::endl;
	//【2】设置支持向量机的参数,SVM中的参数有很多,但是与C_SVC有关的就只有gamma和C,所以只要设置好这两个就可以了
	//     其实,很多资料将gamma设置为0.01,这样训练的收敛速度就会快很多
	CvSVMParams params;
	params.svm_type    = SVM::C_SVC;
	params.kernel_type = SVM::RBF;
	params.degree      = 10.0;
	params.gamma       = 0.01;
	params.coef0       = 1.0;
	params.C           = 10.0;
	params.nu          = 0.5;
	params.p           = 0.1;
	params.term_crit   = cv::TermCriteria(CV_TERMCRIT_EPS,1000,FLT_EPSILON);
	//【3】训练SVM
	std::cout<<"[NOTICE]Starting training process!"<<std::endl;
	startTime = std::clock();
	CvSVM svm;
	svm.train(trainData,trainDataLabels,cv::Mat(),cv::Mat(),params);
	endTime   = std::clock();
	consumeTime   = (endTime - startTime);
	std::cout<<"[NOTICE]Finished training process...consumeTime = "<<consumeTime<<"ms"<<std::endl;
	svm.save("mnist_dataset/mnist_svm.xml");
	std::cout<<"[NOTICE]Save as /mnist_dataset/mnist_svm.xml"<<std::endl;
    //【4】开始导入预测样本
	std::cout<<"[NOTICE]Loading the predict sample!"<<std::endl;
    cv::Mat     testData;
	cv::Mat     testLabels;
	std::cout<<"[NOTICE]Loading sucessfully!"<<std::endl;
	testData  = ReadMnistImage(testImage);
	testLabels= ReadMnistLabel(testLabel);
    //【5】SVM利用训练好的模型开始进行预测
	float count = 0;
	//for(int i=0;i<testData.rows;i++)
	for(int i=0;i<1000;i++)
	{
		cv::Mat sample = testData.row(i);
		float result  = svm.predict(sample);
		result = std::abs(result-testLabels.at<unsigned int>(i,0)<=FLT_EPSILON?1.f:0.f);
		count += result;
	}
	//【6】统计预测的正确个数和错误率
	std::cout<<"[NOTICE]Correct identification number = "<<count<<std::endl;
	std::cout<<"[NOTICE]Error rate = "<<(1000-count+0.0)/1000 * 100.0<<"%..."<<std::endl;
	std::system("pause");
	return 0;
}
下面两幅图片是博主训练1000张图片的准确率,具体怎样设置,请看代码


  • 3
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值