python实现svm

首先,声明一下,这篇文章的目的是分享对svm工具箱的使用心得,希望能对小白、新手快速上手svm分类器,并对它有个初步的了解,有助于以后深入的了解。
所以,本文使用的svm分类器并不是我自己编写的,使用的也是网上找来的工具。

svm下载地址:https://pan.baidu.com/s/189Bc2Kz3-nhJbG1hjFkp5w
提取码:6xtp
里面包含了几个文件在这里插入图片描述
heart_scale是一个心脏有关的数据集,model_file存放训练模型,result是学习结果也就是预测的类别的集合,svm_test_data是测试数据集。
svm是分类算法,svm_train是训练器,svm_test是测试器。

svm的相关概念,这里就不多说了,有不懂的可以先看看这篇博客:
https://blog.csdn.net/varyshare/article/details/90521649
更详细的的公式推导过程:(强烈推荐)
https://www.bilibili.com/video/av28186618?p=1
基础的名词解释可以去看周志华的机器学习。

对于想快速上手的,只需要修改输入的数据集,及对输入数据集做一些处理即可。这里放上对iris.data数据集的分类代码:

# coding:UTF-8

import numpy as np
import svm
import operator
from random import shuffle

def load_data_libsvm(data_file):
    '''导入训练数据
    input:  data_file(string):训练数据所在文件
    output: data(mat):训练样本的特征
            label(mat):训练样本的标签
    '''
    f = open(data_file)
    arrayLines = f.readlines()
    del arrayLines[-1]
    shuffle(arrayLines)
    numberLines = len(arrayLines)
    returnMat = np.zeros((numberLines,4))
    classLabelVector = []
    index = 0
    for line in arrayLines:
        line = line.strip()
        listFromLine = line.split(',')
        returnMat[index,:] = listFromLine[0:4]
        if listFromLine[-1] == 'Iris-setosa':
            classLabelVector.append(1)
        elif listFromLine[-1] == 'Iris-versicolor':
            classLabelVector.append(2)
        elif listFromLine[-1] == 'Iris-virginica':
            classLabelVector.append(3)
        index += 1
    return np.mat(returnMat),np.mat(classLabelVector).T

if __name__ == "__main__":
    # 1、导入训练数据
    print("------------ 加载数据 --------------")
    dataSet, labels = load_data_libsvm("iris.data")
    # 2、训练SVM模型
    print("------------ 训练模型 ---------------")
    C = 1
    toler = 0.001
    maxIter = 5
    svm_model = svm.SVM_training(dataSet, labels, C, toler, maxIter)
    # 3、计算训练的准确性
    print("------------ 计算训练的正确率 --------------")
    accuracy = svm.cal_accuracy(svm_model, dataSet, labels)  
    print("训练的正确率是: %.3f%%" % (accuracy * 100))
    # 4、保存最终的SVM模型
    print("------------ 保存模型 ----------------")
    svm.save_svm_model(svm_model, "model_file")
    

这里的C是惩罚参数,表示分类过程中对间隔大小和分类准确度的偏好的权重,C越大,表示能容忍的误差越小,容易过拟合;C越小,表示允许的误差越大,容易欠拟合。
maxIter表示允许的最大迭代次数。
数据集iris.data在UCI上可以自己下载。

放上测试器的代码:


import numpy as np
import pickle a
  • 7
    点赞
  • 56
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值