sklearn中SVM的可视化

本文代码参考博主twilight0402文章

关于sklearn中SVM使用的文章很多,但其可视化部分往往给出代码但解释通常都不是很详细,本文主要对sklearn训练SVM后的可视化过程进行阐述,故对SVM的使用不做赘述,可以参考其他文章。

1.导入数据:本文使用sklearn中的moon、circle生成数据

import numpy as np
from sklearn import svm, datasets
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# 生成弯月形数据
Data = datasets.make_moons(n_samples=100, noise=0.2, random_state=11)  

# # 生成环形数据
# Data = datasets.make_circles(n_samples=100, noise=0.1, factor=0.1, random_state=11)    

X = Data[0]     # 特征 (100, 2)
y = Data[1]     # 标签,0或1 (100,)

2.构造SVM模型:参考文章,这里使用pipeline构造包含数据归一化和SVM的模型

def creatSVM(kernel='rbf', C=10, gamma=10):
    """
    :param kernel: SVM使用的核函数,本文以rbf核为例,常用于处理分线性分类问题
    :param C: 目标函数惩罚项系数
    :param gamma: rbf核函数参数。实际训练模型时C和gamma需要调参
    """
    return Pipeline([
        ("scaler", StandardScaler()),
        ("SVM", svm.SVC(kernel=kernel, gamma=gamma, C=C))
    ])

mySVM = creatSVM()   # 初始化模型
mySVM.fit(X, y)      # 训练SVM模型

3.模型可视化:

        SVM的可视化主要是希望画出模型分界面(决策函数),当kernel=‘linear’时(即线性SVM)可以直接得到分界面(wx+b=0)参数.coef_[0](对应w)、.intercept_[0](对应b),由此可以计算并绘制分界面。但使用其他核函数时由于实际上进行了特征变换,故无法直接得到分界面参数。
        此时通常采用模拟的方法绘制近似分界面。具体来说,使用训练好的SVM计算指定区域内所有点(当然在算法中实际为抽样的有限个点)的预测结果,利用绘制等高线的函数绘制以输入特征为底面,预测结果为高度的等高线图,对应不同高度(预测结果)的点的分界面即为近似的分界面。

x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1
x2_min, x2_max = X[:, 1].min()-1, X[:, 1].max() + 1
# 获得绘图边界,这里没有区分训练数据或测试数据,根据实际需求选择即可
h = (x1_max - x1_min) / 100   
# h为采样点间隔,可以自己设定
xx, yy = np.meshgrid(np.arange(x1_min, x1_max, h), np.arange(x2_min, x2_max, h))
# 由meshgrid函数生成对应区域内所有点的横纵坐标,xx、yy均为尺寸为(M, N)的二维矩阵,分别对应区域内所有点的横坐标和所有点的纵坐标,同时也是区域内所有样本的第一维特征和第二维特征
z = mySVM.predict(np.c_[xx.ravel(), yy.ravel()])
# 由训练好的SVM预测区域内所有样本的结果。由于xx、yy尺寸均为(M,N),通过.ravel拉平并通过.c_组合,尺寸变为(M*N, 2),相当于M*N个具有两维特征的样本,输出z尺寸为(M*N,)
z = z.reshape(xx.shape)
# 将输出尺寸也转变为(M, N)以和横纵坐标对应绘制等高线图
plt.contourf(xx, yy, z, cmap=plt.cm.ocean, alpha=0.6)
# 绘制等高线图
plt.scatter(X[y == 0, 0], X[y == 0, 1])
plt.scatter(X[y == 1, 0], X[y == 1, 1])
# 标记数据中各样本
plt.title('Visualization of SVM with RBF kernel')
plt.xlabel('x1')
plt.ylabel('x2')
plt.show()

绘制结果如下,由于本文是二分类问题,途中只有两种颜色,对应不同分类结果。不同颜色的等高线界面即为近似分界面

moon数据分类结果
circle数据分类结果

  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值