机器学习算法——SVM简单易懂应用实战

SVM(Support Vector Machine)学习资料总结

实战应用

导入工具

import pandas as pd
import numpy as np
from sklearn import svm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from pylab import mpl
import plotly.graph_objects as go
#设置风格、尺度
sns.set_style('whitegrid')
sns.set_context('paper')

数据展示

# 一个挑西瓜的简单分类例子
data = pd.read_csv('data.csv')
data
# 数据是这样的
	编号	色泽	根蒂	敲声	纹理	脐部	触感	密度		含糖率	好瓜?
0	1	青绿	蜷缩	浊响	清晰	凹陷	硬滑	0.697	0.4601	2	乌黑	蜷缩	沉闷	清晰	凹陷	硬滑	0.774	0.3762	3	乌黑	蜷缩	浊响	清晰	凹陷	硬滑	0.634	0.2643	4	青绿	蜷缩	沉闷	清晰	凹陷	硬滑	0.608	0.3184	5	浅白	蜷缩	浊响	清晰	凹陷	硬滑	0.556	0.2155	6	青绿	稍蜷	浊响	清晰	稍凹	软粘	0.403	0.2376	7	乌黑	稍蜷	浊响	稍糊	稍凹	软粘	0.481	0.1497	8	乌黑	稍蜷	浊响	清晰	稍凹	硬滑	0.437	0.2118	9	乌黑	稍蜷	沉闷	稍糊	稍凹	硬滑	0.666	0.0919	10	青绿	硬挺	清脆	清晰	平坦	软粘	0.243	0.26710	11	浅白	硬挺	清脆	模糊	平坦	硬滑	0.245	0.05711	12	浅白	蜷缩	浊响	模糊	平坦	软粘	0.343	0.09912	13	青绿	稍蜷	浊响	稍糊	凹陷	硬滑	0.639	0.16113	14	浅白	稍蜷	沉闷	稍糊	凹陷	硬滑	0.657	0.19814	15	乌黑	稍蜷	浊响	清晰	稍凹	软粘	0.360	0.37015	16	浅白	蜷缩	浊响	模糊	平坦	硬滑	0.593	0.04216	17	青绿	蜷缩	沉闷	稍糊	稍凹	硬滑	0.719	0.103

将数据转化为数值型

# 定义一个函数将数据集转化为可处理的数值
def str_column_to_int(dataset, columns):
    """
    将类别转化为int型
    @dataset: 数据
    @column: 需要转化的列
    """
    for column in columns:
        class_values = [row[column] for row in dataset]
        unique = set(class_values)
        lookup = dict()
        for i, value in enumerate(unique):
            lookup[value] = i
        for row in dataset:
            row[column] = lookup[row[column]]
    return dataset

使用Sklearn的API

# 读取西瓜数据集
data = pd.read_csv('data.csv')
# 删除编号行
data.drop(columns=['编号'], inplace=True)
# 将数据转化为数值型
data = str_column_to_int(data.values, [0, 1, 2, 3, 4, 5, 8])
# 创建一个线性svm分类器,随便设置一个惩罚参数和内核系数
svm_classifer = svm.SVC(kernel = 'linear', C = 1, gamma = 1)
# 使用这个分类器来分类实验数据
svm_classifer.fit(data[:, :-1].astype(float), data[:, -1].astype(int))
# 查看分类器得分
svm_classifer.score(data[:, :-1].astype(float), data[:, -1].astype(int))

输出的分类准确率结果显示为

0.7058823529411765

绘制结果的混淆矩阵

# 对结果进行可视化
predict = svm_classifer.predict(data[:, :-1].astype(float))
# 设置字体
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['font.sans-serif'] = 'NSimSun,Times New Roman'
font = {
   'family': 'sans-serif',
            'color': 'k',
            'weight': 'normal',
            'size': 20,}

# 将结果整理成混淆矩阵
con_mat = confusion_matrix(data[:, -1].astype(int), predict, labels=[0, 1])
# 使用seaborn绘制
fig &#
  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值