opencv SVM的使用

7 篇文章 0 订阅

参考 :https://blog.csdn.net/qq_35054151/article/details/81840935
https://blog.csdn.net/weixin_33698823/article/details/94496434 
C++ opencv SVM: https://www.cppentry.com/bencandy.php?fid=49&aid=153521&page=2 
参数优化:https://blog.csdn.net/computerme/article/details/38677599

讲得透彻:https://blog.csdn.net/b285795298/article/details/81977271

 

/* SVM识别手写数字mnist 0~9 */
/* 每个数字取80个样本训练(总80*10=800个),每个数字20个验证(总20*10=200个),准确率为84%。*/
/* 每个数字取800个样本训练(总800*10=8000个),每个数字200个验证(总200*10=2000个),准确率为91%。*/
/* alpha=1.0或1.0/255 对学习效果无影响 */
/* 样本数据在 https://download.csdn.net/download/brooknew/11949332 */



#include "opencv2/highgui.hpp"
#include "opencv2/imgproc.hpp"
#include <string>
#include <iostream>
#include <fstream>
#include <vector>
#include "opencv2/ml.hpp"

using namespace std ;
using namespace cv ;
using namespace cv::ml ; 

double alpha = 1.0/255 ;
string trainpath = "D:\\python\\tensorflow\\ministRecognize\\trainImgByDigit\\train_list.txt" ;
string testpath= "D:\\python\\tensorflow\\ministRecognize\\trainImgByDigit\\test_list.txt" ;
const string modelFileName ="svm_model.xml" ;
 
/*
* 读取样本数据和标签,输出SVM的Mat格式
*/
void get_data(string path, Mat &trainData, Mat &trainLabels)
{
    fstream io(path, ios::in);
    if (!io.is_open()){
        cout << "file open error in path : " << path << endl;
        exit(0);
    }
 
    while (!io.eof())
    {
        string msg;
        io >> msg;
 
        trainData.push_back(imread(msg, 0).reshape(0, 1));
 
        io >> msg;
        int idx = msg[0] - '0';
        trainLabels.push_back(Mat(1, 1, CV_32S, &idx));
    }
 
	trainData.convertTo(trainData, CV_32F , alpha );
}
 
/*
* 训练SVM
*/
void svm_train(Ptr<SVM> &model, Mat &trainData, Mat &trainLabels)
{
    model->setType(SVM::C_SVC);     //SVM类型
    model->setKernel(SVM::LINEAR);  //核函数,这里使用线性核
	//model->setKernel(SVM::POLY) ;
    Ptr<TrainData> tData = TrainData::create(trainData, ROW_SAMPLE, trainLabels);
 
    cout << "SVM: start train ..." << endl;
    model->trainAuto(tData);
    cout << "SVM: train success ..." << endl;
}
 
/*
* 利用训练好的SVM预测,以及计算准确率
*/
void svm_predict(Ptr<SVM> &model, Mat test, Mat testLabels )
{
    Mat result;
    float rst = model->predict(test, result);
	int good = 0 ;//准确的个数
    for (auto i = 0; i < result.rows; i++){
		if ( (int)result.at<float>(i, 0) == (int)testLabels.at<int>(i,0) ) {
			good ++ ;
		}else{
			;//cout <<"i=" << i <<"  "<< testLabels.at<int>(i,0) <<":" << result.at<float>(i, 0) << endl ; 
		}
        cout << result.at<float>(i, 0) << "  " ;
		if ( (i+1) % 10 == 0 )
			cout << endl ;
    }
	cout << "Right:" << good << "  Accurary rate: " << int((float)good/result.rows*100) << "%" << endl ; 
}
 
int usingSvm_main(int argc, char* argv[])
{
    string test_path = testpath; 
    string train_path = trainpath ;
 
    Ptr<SVM> model = SVM::create();
    Mat trainData, trainLabels;
    get_data(train_path, trainData, trainLabels);
    svm_train(model, trainData, trainLabels);
	model->save( modelFileName  ) ;//保存模型
 
    Mat testData , testLabels;
    get_data(test_path, testData , testLabels );
    Ptr<SVM> modelV = SVM::load<SVM>( modelFileName ) ; //载入模型
	svm_predict(modelV, testData , testLabels  );
	while( true ) ;
	return 0;
}

创建训练列表的代码:
 

import os
import shutil
KIND = 10
filesInEachSubDir = 1000 

def main():
    dir = 'D:/python/tensorflow/ministRecognize/trainImgByDigit/'
    with open( dir + 'test_list.txt' , 'wt' ) as f1 :    
        with open( dir + 'train_list.txt' , 'wt' ) as f :
            for i in range(KIND):
                subdir = dir + str(i) + '/'
                fn = os.listdir( subdir )
                nfil = min( filesInEachSubDir , len( fn ) )
                nfilTrain = int(nfil*0.8)  
                for nf in range( nfilTrain ) :
                    if (i == KIND-1) and ( nf == nfilTrain-1) :
                        s = subdir + fn[nf] + ' ' + str( i ) 
                    else:
                        s = subdir + fn[nf] + ' ' + str( i ) + '\n'
                    f.write( s )
                for nf in range( nfilTrain , nfil ) :
                    if (i == KIND-1) and ( nf == nfil-1) :
                        s = subdir + fn[nf]+ ' ' + str( i )
                    else :
                        s = subdir + fn[nf] + ' ' + str( i ) + '\n'
                    f1.write( s )

main() 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值