参考 :https://blog.csdn.net/qq_35054151/article/details/81840935
https://blog.csdn.net/weixin_33698823/article/details/94496434
C++ opencv SVM: https://www.cppentry.com/bencandy.php?fid=49&aid=153521&page=2
参数优化:https://blog.csdn.net/computerme/article/details/38677599
讲得透彻:https://blog.csdn.net/b285795298/article/details/81977271
/* SVM识别手写数字mnist 0~9 */
/* 每个数字取80个样本训练(总80*10=800个),每个数字20个验证(总20*10=200个),准确率为84%。*/
/* 每个数字取800个样本训练(总800*10=8000个),每个数字200个验证(总200*10=2000个),准确率为91%。*/
/* alpha=1.0或1.0/255 对学习效果无影响 */
/* 样本数据在 https://download.csdn.net/download/brooknew/11949332 */
#include "opencv2/highgui.hpp"
#include "opencv2/imgproc.hpp"
#include <string>
#include <iostream>
#include <fstream>
#include <vector>
#include "opencv2/ml.hpp"
using namespace std ;
using namespace cv ;
using namespace cv::ml ;
// Scale factor passed to Mat::convertTo when the sample matrix is converted to CV_32F.
double alpha{1.0 / 255};
// Text files listing the training and the test samples.
string trainpath{"D:\\python\\tensorflow\\ministRecognize\\trainImgByDigit\\train_list.txt"};
string testpath{"D:\\python\\tensorflow\\ministRecognize\\trainImgByDigit\\test_list.txt"};
// File the trained model is saved to and reloaded from.
const string modelFileName{"svm_model.xml"};
/**
 * Load a sample list file and append its samples to the SVM input matrices.
 *
 * Each record of the list consists of two whitespace-separated tokens: an
 * image path followed by a label whose first character is the digit class.
 * Every image is read as grayscale, flattened into a single row and appended
 * to trainData; the matching label is appended to trainLabels as a CV_32S
 * value. Once all records are read, trainData is converted to CV_32F and
 * scaled by the global `alpha`.
 *
 * Records whose label does not start with a digit, or whose image cannot be
 * read, are reported on stderr and skipped as a whole, so data rows and label
 * rows stay aligned. If the list file cannot be opened the process terminates
 * with exit status 1.
 *
 * @param path        Path of the list file.
 * @param trainData   Receives one row per sample.
 * @param trainLabels Receives one integer label per sample.
 *
 * NOTE(review): tokens are split on whitespace, so image paths containing
 * spaces cannot be represented in the list file.
 */
void get_data(string path, Mat &trainData, Mat &trainLabels)
{
    ifstream list(path);
    if (!list.is_open()) {
        cerr << "file open error in path : " << path << endl;
        exit(1); // a missing list is a failure, so do not report success to the shell
    }

    string imagePath;
    string labelText;
    // Test the extraction itself instead of eof(): eof() only becomes true
    // after a read has already failed, which made the old loop process one
    // bogus, empty record whenever the file ended with a newline.
    while (list >> imagePath >> labelText) {
        if (labelText[0] < '0' || labelText[0] > '9') {
            cerr << "skip sample with invalid label '" << labelText << "' : " << imagePath << endl;
            continue;
        }

        Mat image = imread(imagePath, IMREAD_GRAYSCALE);
        if (image.empty()) {
            // imread signals failure with an empty Mat; reshaping it would throw.
            cerr << "skip unreadable image : " << imagePath << endl;
            continue;
        }

        trainData.push_back(image.reshape(0, 1)); // same channels, one row
        int label = labelText[0] - '0';
        trainLabels.push_back(Mat(1, 1, CV_32S, &label));
    }

    if (!trainData.empty()) {
        trainData.convertTo(trainData, CV_32F, alpha);
    }
}
/**
 * Configure the given SVM as a linear C-SVC classifier and train it.
 *
 * The samples are wrapped in a TrainData object with one sample per row and
 * handed to SVM::trainAuto. Progress is written to stdout; if trainAuto
 * reports failure, a message is written to stderr instead of the success
 * line.
 *
 * @param model       SVM instance to configure and train.
 * @param trainData   Sample matrix, one sample per row.
 * @param trainLabels Class label of each sample row.
 */
void svm_train(Ptr<SVM> &model, Mat &trainData, Mat &trainLabels)
{
    model->setType(SVM::C_SVC);    // SVM type
    model->setKernel(SVM::LINEAR); // linear kernel (SVM::POLY is an alternative to try)

    Ptr<TrainData> samples = TrainData::create(trainData, ROW_SAMPLE, trainLabels);
    cout << "SVM: start train ..." << endl;
    // trainAuto returns a bool; the old code printed "success" unconditionally.
    if (!model->trainAuto(samples)) {
        cerr << "SVM: train failed ..." << endl;
        return;
    }
    cout << "SVM: train success ..." << endl;
}
/**
 * Classify the test samples with a trained SVM and print the accuracy.
 *
 * Every prediction is printed to stdout, ten per line, followed by a summary
 * line with the number of correct predictions and the accuracy as a whole
 * percentage (truncated). Predictions are read as float and labels as int
 * (i.e. testLabels is expected to be CV_32S).
 *
 * @param model      Trained SVM used for prediction.
 * @param test       Sample matrix, one sample per row.
 * @param testLabels Expected label of each sample row.
 */
void svm_predict(Ptr<SVM> &model, Mat test, Mat testLabels )
{
    Mat result;
    model->predict(test, result);

    // Indexing labels beyond their last row would read out of bounds.
    if (testLabels.rows < result.rows) {
        cerr << "svm_predict: " << result.rows << " predictions but only "
             << testLabels.rows << " labels" << endl;
        return;
    }

    int good = 0; // number of correct predictions
    for (int i = 0; i < result.rows; i++) {
        const float predicted = result.at<float>(i, 0);
        if (static_cast<int>(predicted) == testLabels.at<int>(i, 0)) {
            ++good;
        }
        cout << predicted << " ";
        if ((i + 1) % 10 == 0)
            cout << endl;
    }

    // With no predictions the ratio would be 0/0 (NaN), and converting NaN to
    // int is undefined, so report 0% in that case.
    const int percent = result.rows > 0
        ? static_cast<int>(static_cast<float>(good) / result.rows * 100)
        : 0;
    cout << "Right:" << good << " Accurary rate: " << percent << "%" << endl;
}
int usingSvm_main(int argc, char* argv[])
{
string test_path = testpath;
string train_path = trainpath ;
Ptr<SVM> model = SVM::create();
Mat trainData, trainLabels;
get_data(train_path, trainData, trainLabels);
svm_train(model, trainData, trainLabels);
model->save( modelFileName ) ;//保存模型
Mat testData , testLabels;
get_data(test_path, testData , testLabels );
Ptr<SVM> modelV = SVM::load<SVM>( modelFileName ) ; //载入模型
svm_predict(modelV, testData , testLabels );
while( true ) ;
return 0;
}
创建训练列表的代码:
import os
import shutil
# Number of digit classes; main() iterates over sub-directories named 0 .. KIND-1.
KIND = 10
# Upper bound on how many files are taken from each digit sub-directory.
filesInEachSubDir = 1000
def main():
    """Write train_list.txt and test_list.txt for the per-digit image folders.

    For every digit 0 .. KIND-1 the files of the matching sub-directory are
    listed, at most ``filesInEachSubDir`` of them are used, and the first 80%
    go to the training list while the remainder go to the test list. Each
    output line has the form ``<image path> <digit>``.

    The lists are written without a trailing newline, as before (a reader that
    loops on ``eof()`` would otherwise process one extra, empty record).

    NOTE(review): path and label are separated by a single space, so file
    names containing spaces would produce ambiguous lines.
    """
    base_dir = 'D:/python/tensorflow/ministRecognize/trainImgByDigit/'
    train_lines = []
    test_lines = []
    for digit in range(KIND):
        subdir = base_dir + str(digit) + '/'
        # os.listdir returns entries in arbitrary order; sorting makes the
        # train/test split reproducible from run to run.
        names = sorted(os.listdir(subdir))
        total = min(filesInEachSubDir, len(names))
        n_train = int(total * 0.8)
        for name in names[:n_train]:
            train_lines.append(subdir + name + ' ' + str(digit))
        for name in names[n_train:total]:
            test_lines.append(subdir + name + ' ' + str(digit))

    # Write only after every directory was listed successfully, and join with
    # '\n' so the last line never carries a newline, even when the last digit
    # contributes no entries.
    with open(base_dir + 'test_list.txt', 'wt') as f_test:
        f_test.write('\n'.join(test_lines))
    with open(base_dir + 'train_list.txt', 'wt') as f_train:
        f_train.write('\n'.join(train_lines))
# Generate the lists only when executed as a script, not when imported.
if __name__ == '__main__':
    main()