鸢尾花分类_使用python+sklearn实现在鸢尾花数据集上训练多分类SGD

该博客介绍了如何利用Python和scikit-learn库在鸢尾花数据集上训练多分类的Stochastic Gradient Descent (SGD)模型。内容包括绘制决策面,展示三个一对多分类器的超平面,并提到了脚本的运行时间和内存使用情况。此外,博主还邀请读者加入相关学习群进行深入交流。
摘要由CSDN通过智能技术生成

在鸢尾花数据集上绘制多分类SGD的决策面。三个一对多(one-versus-all)(OVA)分类器的超平面由虚线表示。fd6e393b5c592d35d0802d1eabc9fab3.png

print(__doc__)

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.linear_model import SGDClassifier

# 导入一些数据进行训练
iris = datasets.load_iris()

# 我们仅使用前两个特征。
# 我们可以通过使用二维数据集来避免这种用丑陋的代码来进行切分(slicing)
X = iris.data[:, :2]
y = iris.target
colors = "bry"

# 打乱数据
idx = np.arange(X.shape[0])
np.random.seed(13)
np.random.shuffle(idx)
X = X[idx]
y = y[idx]

# 标准化
mean = X.mean(axis=0)
std = X.std(axis=0)
X = (X - mean) / std

h = .02  # 网格中的步长

clf = SGDClassifier(alpha=0.001, max_iter=100).fit(X, y)

# 创建要绘制的网格
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                     np.arange(y_min, y_max, h))

# 绘制决策边界。 为此,我们将为网格[x_min,x_max] x [y_min,y_max]中的每个点分配颜色。
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
# 将结果放入颜色图(color plot)
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
plt.axis('tight')

# 绘制训练点
for i, color in zip(clf.classes_, colors):
    idx = np.where(y == i)
    plt.scatter(X[idx, 0], X[idx, 1], c=color, label=iris.target_names[i],
                cmap=plt.cm.Paired, edgecolor='black', s=20)
plt.title("Decision surface of multi-class SGD")
plt.axis('tight')

# 绘制三个一对多(one-against-all)的分类器
xmin, xmax = plt.xlim()
ymin, ymax = plt.ylim()
coef = clf.coef_
intercept = clf.intercept_


def plot_hyperplane(c, color):
    def line(x0):
        return (-(x0 * coef[c, 0]) - intercept[c]) / coef[c, 1]

    plt.plot([xmin, xmax], [line(xmin), line(xmax)],
             ls="--", color=color)


for i, color in zip(clf.classes_, colors):
    plot_hyperplane(i, color)
plt.legend()
plt.show()

脚本的总运行时间:(0分钟0.585秒)

估计的内存使用量: 8 MB

9543c6c5c762ab3ba362e96a8e98693d.png

下载Python源代码: plot_sgd_iris.py

下载Jupyter notebook源代码: plot_sgd_iris.ipynb

由Sphinx-Gallery生成的画廊

3d7ad83fb09dc78f7b3806e2f0b8227e.png ☆☆☆为方便大家查阅,小编已将scikit-learn学习路线专栏文章统一整理到公众号底部菜单栏,同步更新中,关注公众号,点击左下方“系列文章”,如图: 155fcf578b2e233c518142c9abc9b2b5.png

欢迎大家和我一起沿着scikit-learn文档这条路线,一起巩固机器学习算法基础。(添加微信:mthler,备注:sklearn学习,一起进【sklearn机器学习进步群】开启打怪升级的学习之旅。

fc9cf12ab7b9e893f49229b0c884b8ec.png

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值