首先,声明一下,这篇文章的目的是分享对svm工具箱的使用心得,希望能对小白、新手快速上手svm分类器,并对它有个初步的了解,有助于以后深入的了解。
所以,本文使用的svm分类器并不是我自己编写的,使用的也是网上找来的工具。
svm下载地址:https://pan.baidu.com/s/189Bc2Kz3-nhJbG1hjFkp5w
提取码:6xtp
里面包含了几个文件
heart_scale是一个心脏有关的数据集,model_file存放训练模型,result是学习结果也就是预测的类别的集合,svm_test_data是测试数据集。
svm是分类算法,svm_train是训练器,svm_test是测试器。
svm的相关概念,这里就不多说了,有不懂的可以先看看这篇博客:
https://blog.csdn.net/varyshare/article/details/90521649
更详细的的公式推导过程:(强烈推荐)
https://www.bilibili.com/video/av28186618?p=1
基础的名词解释可以去看周志华的机器学习。
对于想快速上手的,只需要修改输入的数据集,及对输入数据集做一些处理即可。这里放上对iris.data数据集的分类代码:
# coding:UTF-8
import numpy as np
import svm
import operator
from random import shuffle
def load_data_libsvm(data_file):
'''导入训练数据
input: data_file(string):训练数据所在文件
output: data(mat):训练样本的特征
label(mat):训练样本的标签
'''
f = open(data_file)
arrayLines = f.readlines()
del arrayLines[-1]
shuffle(arrayLines)
numberLines = len(arrayLines)
returnMat = np.zeros((numberLines,4))
classLabelVector = []
index = 0
for line in arrayLines:
line = line.strip()
listFromLine = line.split(',')
returnMat[index,:] = listFromLine[0:4]
if listFromLine[-1] == 'Iris-setosa':
classLabelVector.append(1)
elif listFromLine[-1] == 'Iris-versicolor':
classLabelVector.append(2)
elif listFromLine[-1] == 'Iris-virginica':
classLabelVector.append(3)
index += 1
return np.mat(returnMat),np.mat(classLabelVector).T
if __name__ == "__main__":
# 1、导入训练数据
print("------------ 加载数据 --------------")
dataSet, labels = load_data_libsvm("iris.data")
# 2、训练SVM模型
print("------------ 训练模型 ---------------")
C = 1
toler = 0.001
maxIter = 5
svm_model = svm.SVM_training(dataSet, labels, C, toler, maxIter)
# 3、计算训练的准确性
print("------------ 计算训练的正确率 --------------")
accuracy = svm.cal_accuracy(svm_model, dataSet, labels)
print("训练的正确率是: %.3f%%" % (accuracy * 100))
# 4、保存最终的SVM模型
print("------------ 保存模型 ----------------")
svm.save_svm_model(svm_model, "model_file")
这里的C是惩罚参数,表示分类过程中对间隔大小和分类准确度的偏好的权重,C越大,表示能容忍的误差越小,容易过拟合;C越小,表示允许的误差越大,容易欠拟合。
maxIter表示允许的最大迭代次数。
数据集iris.data在UCI上可以自己下载。
放上测试器的代码:
import numpy as np
import pickle a