基于sklearn的简单分类器
输入 输出
3 1 0
2 5 1
1 8 1
6 4 0
5 2 0
3 5 1
4 7 1
4 -1 0
7 5 ?
已知部分输入和部分输出求当输入为7、5时输出为多少?
我们观察上面的规律不难发现,当输入的第一个数大于第二个数时输出为0,当输入的第一个数小于第二个数时输出为1,因此我们可根据输入和输出关系利用matplotlib.pyplot在坐标系中将输入对应点将输出对应点的颜色表示出来。
"""
简单的分类器
"""
# 导入需求模块
import numpy as np
import matplotlib.pyplot as mp
x = np.array([
[3, 1],
[2, 5],
[1, 8],
[6, 4],
[5, 2],
[3, 5],
[4, 7],
[4, -1]])
y = np.array([0, 1, 1, 0, 0, 1, 1, 0])
# 文件名称
mp.figure("Simple Classification", facecolor="lightgray")
# 设置标题
mp.title("Simple Classification", fontsize=20)
# x坐标
mp.xlabel("x", fontsize=16)
# y坐标
mp.ylabel("y", fontsize=16)
# 坐标大小
mp.tick_params(labelsize=10)
# 输出为散点图并赋予点的颜色
mp.scatter(x[:, 0], x[:, 1], c=y, cmap="brg", s=60)
# 窗口显示
mp.show()
显示效果如下:
倘若要对这两类输出做一个简单的分类我们假设可以在图像上做出一条y=kx的直线,位于直线上方的点都视为输出为1的类,位于直线下方的点都视为输出为0的类,而这条y=kx的直线即为0类和1类的分界线,由于数据样本有限我们可以做出以上假设。
假如该平面有若干个点,他们密集的分布在该平面内,我们利用上述发现的规律——将第一个数字大于第二个数字的输入标记为0,将第一个数字小于第二个数字的输入标记为1,然后在图像上显示出来,他们所构成的图形的界限即为分类边界。
"""
简单的分类器
"""
# 导入需求模块
import numpy as np
import matplotlib.pyplot as mp
x = np.array([
[3, 1],
[2, 5],
[1, 8],
[6, 4],
[5, 2],
[3, 5],
[4, 7],
[4, -1]])
y = np.array([0, 1, 1, 0, 0, 1, 1, 0])
# 绘制分类边界
# l:左边界,r:右边界,h:点和点的水平距离(为了防止点在边界上,我们把边界再向外移动一个像素的距离)
l, r, h = x[:, 0].min()-1, x[:, 0].max()+1, 0.05
# b:底边界,t:顶边界,v:点和点的竖直距离
b, t, v = x[:, 1].min()-1, x[:, 1].max()+1, 0.05
# 生成二维点阵
grid_x = np.meshgrid(np.arange(l, r, h), np.arange(b, t, v))
# 将二维点阵展平拼接成若干行两列的二维数组
flat_x = np.c_[grid_x[0].ravel(), grid_x[1].ravel()]
# 初始化y,并定义数据类型为int型
flat_y = np.zeros(len(flat_x), dtype=int)
# 将第一列小于第二列的标记为1
flat_y[flat_x[:, 0] < flat_x[:, 1]] = 1
flat_y[flat_x[:, 0] > flat_x[:, 1]] = 0
grid_y = flat_y.reshape(grid_x[0].shape)
# 文件名称
mp.figure("Simple Classification", facecolor="lightgray")
# 设置标题
mp.title("Simple Classification", fontsize=20)
# x坐标
mp.xlabel("x", fontsize=16)
# y坐标
mp.ylabel("y", fontsize=16)
# 坐标大小
mp.tick_params(labelsize=10)
# 用颜色绘制网格
mp.pcolormesh(grid_x[0], grid_x[1], grid_y, cmap="gray")
# 输出为散点图并赋予点的颜色
mp.scatter(x[:, 0], x[:, 1], c=y, cmap="brg", s=60)
# 窗口显示
mp.show()
显示效果如下:
我们可以看到,根据我们的假设,使用若干个点来对模型进行检验,分界线确实为y=kx的直线,样本点依然位于分类标准两侧,因此假设成立。
注:仔细放大图片位于黑白交界处的棱角受水平和竖直的步长影响,为了提高图形质量可缩小步长即可,笔者为了提高运行效率因此步长选用0.05。