一、介绍
-
支持向量机(support vector machines,SVM):
- 寻找一个超平面对样本进行分割 + 分割原则是间隔最大化 + 最终转换为一个凸二次规划问题求解,线性、非线性和分类回归问题均可以。
- 训练样本线性可分时,使用硬间隔最大化;训练样本线性不可分时,使用核函数和软间隔最大化。
- 核函数:将样本从原始空间映射到一个更高维的特征空间,使得样本在这个特征空间内线性可分,即使得线性不可分的数据变得线性可分。(常用的核函数:线性核、多项式核、高斯核RBF、拉普拉斯核、Sigmoid核。)
-
SVM主要解决了两类问题:
- 寻找最优的超平面(寻找一个超平面可以使得与它最近的样本点的距离大于其他所有超平面划分时与最近样本点的距离,间隔最大化)。
- 能够划分非线性可分的样本。
二、代码实现
from sklearn.svm import LinearSVC, LinearSVR, NuSVC, NuSVR, OneClassSVM, SVC, SVR
的参数和属性
-
OVR :一对多法(one-versus-rest, 简称OVR SVMs)
把某个类别归为一类,其他剩余的类别样本归为另一类,这样k个类别的样本就构造了k个SVM分类器,预测结果取k个分类器中值最大的一个作为最后的预测输出。如A, B, C, D四类,A为一类,剩下B, C, D为一类,以此类推。 -
OVO: 一对一法(one-versus-one, 简称OVO SVMs或者pairwise):
在任意两类样本之间训练一个SVM分类器,所以k个类别的样本就需要 k ( k − 1 ) / 2 k(k-1)/2 k(k−1)/2个SVM分类器。类别多的时候代价较大。
代码实现
#coding=utf-8
# coding=utf-8
import time
import numpy as np
from sklearn.svm import SVC
from sklearn import preprocessing
# 1. 加载数据
def loadData(fileName):
'''
加载Mnist数据集
:param fileName:要加载的数据集路径
:return: list形式的数据集及标记
'''
print('start to read data')
# 存放数据及标记的list
dataArr = []
labelArr = []
# 打开文件
fr = open(fileName, 'r')
# 将文件按行读取
for line in fr.readlines():
# 对每一行数据按切割福','进行切割,返回字段列表
curLine = line.strip().split(',')
# Mnsit有0-9是个标记,由于是二分类任务,所以将>=5的作为1,<5为-1
# if int(curLine[0]) >= 5:
# labelArr.append(1)
# else:
# labelArr.append(-1)
labelArr.append(int(curLine[0]))
# 存放标记
# [int(num) for num in curLine[1:]] -> 遍历每一行中除了以第一哥元素(标记)外将所有元素转换成int类型
# [int(num)/255 for num in curLine[1:]] -> 将所有数据除255归一化(非必须步骤,可以不归一化)
# dataArr.append([int(num)/255 for num in curLine[1:]])
dataArr.append([num for num in curLine[1:]])
labelArr = np.ravel(labelArr)
dataArr = np.array(dataArr)
# dataArr = preprocessing.StandardScaler().fit(dataArr)
#返回data和label
return dataArr, labelArr
def svm_model(trainDataList, trainLabelList, iter=2000):
model = SVC(kernel='rbf', max_iter=iter, decision_function_shape='ovo')
model.fit(trainDataList, trainLabelList)
return model
def model_test(model, testDataList, testLabelList):
accuracy = model.score(testDataList, testLabelList)
return accuracy
if __name__ == '__main__':
start = time.time()
# 获取训练集及标签
print('start read transSet')
trainData, trainLabel = loadData('../Mnist/mnist_train.csv')
# 获取测试集及标签
print('start read testSet')
testData, testLabel = loadData('../Mnist/mnist_test.csv')
# 开始训练,学习w
print('start to train')
model = svm_model(trainData, trainLabel)
# 验证正确率
print('start to test')
accuracy = model_test(model, testData, testLabel)
# 打印准确率
print('the accuracy is:', accuracy)
# 打印时间
print('time span:', time.time() - start)
ps:本博客仅供自己复习理解,不具其他人可参考,本博客参考了大量的优质资源,侵删。