机器学习:SVM的代码实现

目录

前言

一、完整代码

二、输出结果

三、实现步骤解析

1.读取数据

2.创建模型并训练

3.可视化SVM结果

总结


前言

        支持向量机(SVM,Support Vector Machine)是一种用于分类和回归的监督学习算法。它的核心思想是通过在特征空间中找到一个最佳的分隔超平面来将数据分成不同的类别。

 

一、完整代码

import pandas as pd

# 读取数据
data = pd.read_csv('iris.csv', header=None)

"""
使用SVM进行训练
"""
from sklearn.svm import SVC  # SVC做分类  SVR做回归

# 获取特征和标签
x = data.iloc[:, [1, 3]]
y = data.iloc[:, -1]
svm = SVC(kernel='linear', C=10, random_state=0)   # C=float('inf')将软间隔的惩罚设置为无穷大
svm.fit(x, y)

"""
可视化SVM结果
"""
# 参数w[原始数据为二维数组]
w = svm.coef_[0]
# 偏置项b[原始数据为一维数组]
b = svm.intercept_[0]
# 超平面方程:w1x1+w2x2+b=0
# ->>x2 = -(w1x1+b)/w2
import numpy as np

x1 = np.linspace(0, 7, 300)  # 在0-7内生成300个数据
# 超平面方程
x2 = -(w[0] * x1 + b) / w[1]
# 上超平面方程
x3 = (1 - (w[0] * x1 + b)) / w[1]
# 下超平面方程
x4 = (-1 - (w[0] * x1 + b)) / w[1]

import matplotlib.pyplot as plt

data1 = data.iloc[:50, :]
data2 = data.iloc[50:, :]
# 原数据为四维 无法展示 这里选择两个特征进行二维展示
plt.scatter(data1[1], data1[3], marker='+')
plt.scatter(data2[1], data2[3], marker='o')
# plt.show()

# 可视化超平面
plt.plot(x1, x2, linewidth=2, color='r')
plt.plot(x1, x3, linewidth=1, color='r', linestyle='--')
plt.plot(x1, x4, linewidth=1, color='r', linestyle='--')
# 进行坐标轴限制
plt.xlim(4, 7)
plt.ylim(0, 6)

# 可视化支持向量
vets = svm.support_vectors_
plt.scatter(vets[:, 0], vets[:, 1], c='b', marker='x')
plt.show()

 

二、输出结果

 

三、实现步骤解析

1.读取数据

  • 这里使用的是鸢尾花的数据
import pandas as pd

# 读取数据
data = pd.read_csv('iris.csv', header=None)

 

2.创建模型并训练

  • svm模型里的C参数可以用来控制惩罚力度进而控制软间隔的程度
    • C越大,惩罚越严格,软间隔程度越小,越准确,但也越容易过拟合
    • C越小,惩罚越不严格,软间隔程度越大,越不准确,但也越不容易过拟合
"""
使用SVM进行训练
"""
from sklearn.svm import SVC  # SVC做分类  SVR做回归

# 获取特征和标签
x = data.iloc[:, [1, 3]]
y = data.iloc[:, -1]
svm = SVC(kernel='linear', C=10, random_state=0)   # C=float('inf')将软间隔的惩罚设置为无穷大
svm.fit(x, y)

 

3.可视化SVM结果

  • 获取svm模型里返回的系数和截距
  • 再通过系数和截距求出各直线方程
  • 最后进行二维的展示
"""
可视化SVM结果
"""
# 参数w[原始数据为二维数组]
w = svm.coef_[0]
# 偏置项b[原始数据为一维数组]
b = svm.intercept_[0]
import numpy as np

x1 = np.linspace(0, 7, 300)  # 在0-7内生成300个数据
# 超平面方程
x2 = -(w[0] * x1 + b) / w[1]
# 上超平面方程
x3 = (1 - (w[0] * x1 + b)) / w[1]
# 下超平面方程
x4 = (-1 - (w[0] * x1 + b)) / w[1]

import matplotlib.pyplot as plt

data1 = data.iloc[:50, :]
data2 = data.iloc[50:, :]
# 原数据为四维 无法展示 这里选择两个特征进行二维展示
plt.scatter(data1[1], data1[3], marker='+')
plt.scatter(data2[1], data2[3], marker='o')
# plt.show()

# 可视化超平面
plt.plot(x1, x2, linewidth=2, color='r')
plt.plot(x1, x3, linewidth=1, color='r', linestyle='--')
plt.plot(x1, x4, linewidth=1, color='r', linestyle='--')
# 进行坐标轴限制
plt.xlim(4, 7)
plt.ylim(0, 6)

# 可视化支持向量
vets = svm.support_vectors_
plt.scatter(vets[:, 0], vets[:, 1], c='b', marker='x')
plt.show()

 

总结

        总的来说,SVM可以使用核函数处理非线性问题,通过将数据映射到更高维的空间。正则化参数C控制分类准确性与模型复杂度之间的平衡。SVM广泛应用于文本分类、图像识别、生物信息学和金融预测等领域。

  • 9
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

吃什么芹菜卷

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值