在鸢尾花数据集上绘制多分类SGD的决策面。三个一对多(one-versus-all)(OVA)分类器的超平面由虚线表示。
print(__doc__)
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.linear_model import SGDClassifier
# 导入一些数据进行训练
iris = datasets.load_iris()
# 我们仅使用前两个特征。
# 我们可以通过使用二维数据集来避免这种用丑陋的代码来进行切分(slicing)
X = iris.data[:, :2]
y = iris.target
colors = "bry"
# 打乱数据
idx = np.arange(X.shape[0])
np.random.seed(13)
np.random.shuffle(idx)
X = X[idx]
y = y[idx]
# 标准化
mean = X.mean(axis=0)
std = X.std(axis=0)
X = (X - mean) / std
h = .02 # 网格中的步长
clf = SGDClassifier(alpha=0.001, max_iter=100).fit(X, y)
# 创建要绘制的网格
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
np.arange(y_min, y_max, h))
# 绘制决策边界。 为此,我们将为网格[x_min,x_max] x [y_min,y_max]中的每个点分配颜色。
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
# 将结果放入颜色图(color plot)
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
plt.axis('tight')
# 绘制训练点
for i, color in zip(clf.classes_, colors):
idx = np.where(y == i)
plt.scatter(X[idx, 0], X[idx, 1], c=color, label=iris.target_names[i],
cmap=plt.cm.Paired, edgecolor='black', s=20)
plt.title("Decision surface of multi-class SGD")
plt.axis('tight')
# 绘制三个一对多(one-against-all)的分类器
xmin, xmax = plt.xlim()
ymin, ymax = plt.ylim()
coef = clf.coef_
intercept = clf.intercept_
def plot_hyperplane(c, color):
def line(x0):
return (-(x0 * coef[c, 0]) - intercept[c]) / coef[c, 1]
plt.plot([xmin, xmax], [line(xmin), line(xmax)],
ls="--", color=color)
for i, color in zip(clf.classes_, colors):
plot_hyperplane(i, color)
plt.legend()
plt.show()
脚本的总运行时间:(0分钟0.585秒)
估计的内存使用量: 8 MB
下载Python源代码: plot_sgd_iris.py
下载Jupyter notebook源代码: plot_sgd_iris.ipynb
由Sphinx-Gallery生成的画廊
☆☆☆为方便大家查阅,小编已将scikit-learn学习路线专栏文章统一整理到公众号底部菜单栏,同步更新中,关注公众号,点击左下方“系列文章”,如图:欢迎大家和我一起沿着scikit-learn文档这条路线,一起巩固机器学习算法基础。(添加微信:mthler,备注:sklearn学习,一起进【sklearn机器学习进步群】开启打怪升级的学习之旅。