// opencv3 svm实现手写数字集MNIST分类 (MNIST handwritten-digit classification with an OpenCV 3 SVM)

#include <algorithm>
#include <cfloat>
#include <chrono>
#include <cmath>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>

using namespace std;
using namespace cv;
using namespace cv::ml;

// Reverses the byte order of a 32-bit integer (byte 0 <-> byte 3, byte 1 <-> byte 2).
// The MNIST readers use it to turn the big-endian header fields read raw from
// the IDX files into host integers; that is only correct on a little-endian
// host (NOTE(review): assumption, not checked at runtime).
int ReverseInt(int i)
{
    const unsigned int bits = static_cast<unsigned int>(i);
    const unsigned int swapped = ((bits & 0x000000FFu) << 24)
                               | ((bits & 0x0000FF00u) << 8)
                               | ((bits & 0x00FF0000u) >> 8)
                               | ((bits & 0xFF000000u) >> 24);
    return static_cast<int>(swapped);
}

void read_Mnist_Label(string filename, Mat &labels,int num)
{
    ifstream file(filename, ios::binary);
    if (file.is_open())
    {
        int magic_number = 0;
        int number_of_images = 0;
        file.read((char*)&magic_number, sizeof(magic_number));
        file.read((char*)&number_of_images, sizeof(number_of_images));
        magic_number = ReverseInt(magic_number);
        number_of_images = ReverseInt(number_of_images);
        cout << "magic number = " << magic_number << endl;
        cout << "number of images = " << number_of_images << endl;
        int row_index = 0;
        for (int i = 0; i < num; i++)
        {
            unsigned char label = 0;
            file.read((char*)&label, sizeof(label));
            labels.at<int>(row_index,0) = label;
            row_index++;
        }
    }
}

void read_Mnist_Images(string filename, Mat& images,int num)
{
    ifstream file(filename, ios::binary);
    if (file.is_open())
    {
        int magic_number = 0;
        int number_of_images = 0;
        int n_rows = 0;
        int n_cols = 0;
        unsigned char label;
        file.read((char*)&magic_number, sizeof(magic_number));
        file.read((char*)&number_of_images, sizeof(number_of_images));
        file.read((char*)&n_rows, sizeof(n_rows));
        file.read((char*)&n_cols, sizeof(n_cols));
        magic_number = ReverseInt(magic_number);
        number_of_images = ReverseInt(number_of_images);
        n_rows = ReverseInt(n_rows);
        n_cols = ReverseInt(n_cols);
        cout << "magic number = " << magic_number << endl;
        cout << "number of images = " << number_of_images << endl;
        cout << "rows = " << n_rows << endl;
        cout << "cols = " << n_cols << endl;
        for (int i = 0; i < num; i++)
        {
            int col_index = 0;
            for (int r = 0; r < n_rows; r++)
            {
                for (int c = 0; c < n_cols; c++)
                {
                    unsigned char image = 0;
                    file.read((char*)&image, sizeof(image));
                    images.at<float>(i, col_index) = (float)image;
                    col_index++;
                }
            }
        }
    }
}
int main()
{
    int trainNum = 10000;
    int testNum = 1000;
    Mat trainData = Mat::zeros(Size(784, trainNum), CV_32FC1);
    Mat testData = Mat::zeros(Size(784, testNum), CV_32FC1);
    Mat trainLabels = Mat::zeros(Size(1, trainNum), CV_32SC1);
    Mat testLabels = Mat::zeros(Size(1, testNum), CV_32SC1);
    read_Mnist_Images("train-images.idx3-ubyte", trainData,trainNum);
    read_Mnist_Images("t10k-images.idx3-ubyte", testData, testNum);
    read_Mnist_Label("train-labels.idx1-ubyte", trainLabels, trainNum);
    read_Mnist_Label("t10k-labels.idx1-ubyte", testLabels, testNum);

    trainData = trainData / 255;
    testData = testData / 255;

    cout << trainData.rows << " " << trainData.cols << endl;
    cout << testData.rows << " " << testData.cols << endl;

 //------------------------ 2. Set up the support vector machines parameters --------------------
    Ptr<SVM> svm = SVM::create();
    svm->setType(SVM::C_SVC);
    svm->setKernel(SVM::RBF);
    svm->setDegree(0);
    svm->setGamma(0.01);
    svm->setCoef0(1.0);
    svm->setC(10.0);
    svm->setNu(0);
    svm->setP(0.1);
    svm->setTermCriteria(TermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON));
    Ptr<TrainData> tdata = TrainData::create(trainData, ROW_SAMPLE, trainLabels);
    //------------------------ 3. Train the svm ----------------------------------------------------
    cout << "Starting training process" << endl;
    double start_time_ = clock();
    //svm->train(tdata);
    double end_time_ = clock();
    double cost_time_ = (end_time_ - start_time_) / CLOCKS_PER_SEC;
    cout << "Finished training process...cost " << cost_time_ << " seconds..." << endl;
    
    //------------------------ 4. save the svm ----------------------------------------------------
    //svm->save("mnist_svm.xml");
    cout << "save as /mnist_dataset/mnist_svm.xml" << endl;
    //------------------------ 5. load the svm ----------------------------------------------------
    cout << "开始导入SVM文件...\n";
    Ptr<SVM> svm1 = StatModel::load<SVM>("mnist_svm.xml");
    cout << "成功导入SVM文件...\n";
    //
    float count = 0;
    for (int i = 0; i < testData.rows; i++) {
        Mat sample = testData.row(i);
        float res = svm1->predict(sample);
        float realdata  = testLabels.at<int>(i, 0);
        res = abs(res - realdata) <= FLT_EPSILON ? 1.f : 0.f;
        count+= res;
    }
    cout << "正确的识别个数 count = " << count << endl;
    cout << "正确率为..." << (count + 0.0) / testNum * 100.0 << "%....\n";
    system("pause");
    return 0;
}

/* Residue of the web page this example was copied from (not part of the program):
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
*/