#include <opencv2\opencv.hpp>
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
using namespace std;
using namespace cv;
using namespace cv::ml;
/**
 * Reverses the order of the four low bytes of `i`.
 *
 * Byte 0 (least significant) becomes byte 3, byte 1 becomes byte 2, and so on.
 * The readers in this file apply it to every 32-bit header field they read
 * from the MNIST files before using the value.
 *
 * @param i  Value whose bytes are to be swapped.
 * @return   The byte-swapped value.
 */
int ReverseInt(int i)
{
    // Do the swap on an unsigned copy: shifting a signed int so that a set bit
    // lands in the sign position is not portable, which the original did
    // whenever the low byte of `i` was >= 0x80.
    const unsigned int value = static_cast<unsigned int>(i);
    const unsigned int swapped = ((value & 0x000000FFu) << 24) |
                                 ((value & 0x0000FF00u) << 8) |
                                 ((value & 0x00FF0000u) >> 8) |
                                 ((value & 0xFF000000u) >> 24);
    return static_cast<int>(swapped);
}
/**
 * Reads labels from an MNIST label file into the first column of `labels`.
 *
 * The file starts with two 32-bit header fields (magic number, item count),
 * each byte-swapped with ReverseInt() before use, followed by one unsigned
 * byte per label. Both header values are echoed to stdout.
 *
 * At most `num` labels are read, and never more than the file declares or
 * than `labels` has rows. Rows beyond the number actually read are left
 * untouched. Problems (file cannot be opened, truncated data, unsuitable
 * matrix) are reported on stderr and the function returns without throwing.
 *
 * @param filename  Path of the label file.
 * @param labels    Pre-allocated CV_32SC1 matrix with at least one column;
 *                  row i, column 0 receives label i.
 * @param num       Maximum number of labels to read.
 */
void read_Mnist_Label(string filename, Mat &labels, int num)
{
    ifstream file(filename, ios::binary);
    if (!file.is_open())
    {
        cerr << "read_Mnist_Label: cannot open " << filename << endl;
        return;
    }

    int magic_number = 0;
    int number_of_images = 0;
    file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
    file.read(reinterpret_cast<char*>(&number_of_images), sizeof(number_of_images));
    if (!file)
    {
        cerr << "read_Mnist_Label: truncated header in " << filename << endl;
        return;
    }
    magic_number = ReverseInt(magic_number);
    number_of_images = ReverseInt(number_of_images);
    cout << "magic number = " << magic_number << endl;
    cout << "number of images = " << number_of_images << endl;

    // labels.at<int>() below is only valid for a 32-bit signed single-channel matrix.
    if (labels.type() != CV_32SC1 || labels.cols < 1)
    {
        cerr << "read_Mnist_Label: labels must be a CV_32SC1 matrix with at least one column" << endl;
        return;
    }

    // Never read past what the file declares or write past the end of the matrix.
    int count = num;
    if (count > number_of_images)
        count = number_of_images;
    if (count > labels.rows)
        count = labels.rows;

    for (int row = 0; row < count; row++)
    {
        unsigned char label = 0;
        file.read(reinterpret_cast<char*>(&label), sizeof(label));
        if (!file)
        {
            cerr << "read_Mnist_Label: " << filename << " ended after " << row << " labels" << endl;
            break;
        }
        labels.at<int>(row, 0) = label;
    }
}
/**
 * Reads images from an MNIST image file into the rows of `images`.
 *
 * The file starts with four 32-bit header fields (magic number, image count,
 * rows, cols), each byte-swapped with ReverseInt() before use, followed by
 * rows*cols unsigned bytes per image. All four header values are echoed to
 * stdout. Image i is flattened row by row into row i of `images`, each pixel
 * converted to float without scaling.
 *
 * At most `num` images are read, and never more than the file declares or
 * than `images` has rows. Rows beyond the number actually read are left
 * untouched. Problems (file cannot be opened, truncated data, matrix of the
 * wrong type or too narrow for one image) are reported on stderr and the
 * function returns without throwing.
 *
 * @param filename  Path of the image file.
 * @param images    Pre-allocated CV_32FC1 matrix with at least rows*cols columns.
 * @param num       Maximum number of images to read.
 */
void read_Mnist_Images(string filename, Mat& images, int num)
{
    ifstream file(filename, ios::binary);
    if (!file.is_open())
    {
        cerr << "read_Mnist_Images: cannot open " << filename << endl;
        return;
    }

    int magic_number = 0;
    int number_of_images = 0;
    int n_rows = 0;
    int n_cols = 0;
    file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
    file.read(reinterpret_cast<char*>(&number_of_images), sizeof(number_of_images));
    file.read(reinterpret_cast<char*>(&n_rows), sizeof(n_rows));
    file.read(reinterpret_cast<char*>(&n_cols), sizeof(n_cols));
    if (!file)
    {
        cerr << "read_Mnist_Images: truncated header in " << filename << endl;
        return;
    }
    magic_number = ReverseInt(magic_number);
    number_of_images = ReverseInt(number_of_images);
    n_rows = ReverseInt(n_rows);
    n_cols = ReverseInt(n_cols);
    cout << "magic number = " << magic_number << endl;
    cout << "number of images = " << number_of_images << endl;
    cout << "rows = " << n_rows << endl;
    cout << "cols = " << n_cols << endl;

    if (n_rows <= 0 || n_cols <= 0)
    {
        cerr << "read_Mnist_Images: invalid image size in " << filename << endl;
        return;
    }

    // Compute the pixel count in a wider type so a corrupt header cannot overflow int.
    const long long pixel_count = static_cast<long long>(n_rows) * n_cols;

    // images.ptr<float>() below is only valid for a 32-bit float single-channel
    // matrix, and each row must be wide enough to hold one flattened image.
    if (images.type() != CV_32FC1 || pixel_count > images.cols)
    {
        cerr << "read_Mnist_Images: images must be CV_32FC1 with at least "
             << pixel_count << " columns" << endl;
        return;
    }
    const int pixels = static_cast<int>(pixel_count);

    // Never read past what the file declares or write past the end of the matrix.
    int count = num;
    if (count > number_of_images)
        count = number_of_images;
    if (count > images.rows)
        count = images.rows;

    // One bulk read per image instead of one stream call per pixel.
    vector<unsigned char> buffer(pixels);
    for (int i = 0; i < count; i++)
    {
        file.read(reinterpret_cast<char*>(buffer.data()), pixels);
        if (!file)
        {
            cerr << "read_Mnist_Images: " << filename << " ended after " << i << " images" << endl;
            break;
        }
        float* dst = images.ptr<float>(i);
        for (int p = 0; p < pixels; p++)
        {
            dst[p] = static_cast<float>(buffer[p]);
        }
    }
}
int main()
{
int trainNum = 10000;
int testNum = 1000;
Mat trainData = Mat::zeros(Size(784, trainNum), CV_32FC1);
Mat testData = Mat::zeros(Size(784, testNum), CV_32FC1);
Mat trainLabels = Mat::zeros(Size(1, trainNum), CV_32SC1);
Mat testLabels = Mat::zeros(Size(1, testNum), CV_32SC1);
read_Mnist_Images("train-images.idx3-ubyte", trainData,trainNum);
read_Mnist_Images("t10k-images.idx3-ubyte", testData, testNum);
read_Mnist_Label("train-labels.idx1-ubyte", trainLabels, trainNum);
read_Mnist_Label("t10k-labels.idx1-ubyte", testLabels, testNum);
trainData = trainData / 255;
testData = testData / 255;
cout << trainData.rows << " " << trainData.cols << endl;
cout << testData.rows << " " << testData.cols << endl;
//------------------------ 2. Set up the support vector machines parameters --------------------
Ptr<SVM> svm = SVM::create();
svm->setType(SVM::C_SVC);
svm->setKernel(SVM::RBF);
svm->setDegree(0);
svm->setGamma(0.01);
svm->setCoef0(1.0);
svm->setC(10.0);
svm->setNu(0);
svm->setP(0.1);
svm->setTermCriteria(TermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON));
Ptr<TrainData> tdata = TrainData::create(trainData, ROW_SAMPLE, trainLabels);
//------------------------ 3. Train the svm ----------------------------------------------------
cout << "Starting training process" << endl;
double start_time_ = clock();
//svm->train(tdata);
double end_time_ = clock();
double cost_time_ = (end_time_ - start_time_) / CLOCKS_PER_SEC;
cout << "Finished training process...cost " << cost_time_ << " seconds..." << endl;
//------------------------ 4. save the svm ----------------------------------------------------
//svm->save("mnist_svm.xml");
cout << "save as /mnist_dataset/mnist_svm.xml" << endl;
//------------------------ 5. load the svm ----------------------------------------------------
cout << "开始导入SVM文件...\n";
Ptr<SVM> svm1 = StatModel::load<SVM>("mnist_svm.xml");
cout << "成功导入SVM文件...\n";
//
float count = 0;
for (int i = 0; i < testData.rows; i++) {
Mat sample = testData.row(i);
float res = svm1->predict(sample);
float realdata = testLabels.at<int>(i, 0);
res = abs(res - realdata) <= FLT_EPSILON ? 1.f : 0.f;
count+= res;
}
cout << "正确的识别个数 count = " << count << endl;
cout << "正确率为..." << (count + 0.0) / testNum * 100.0 << "%....\n";
system("pause");
return 0;
}