[题外话]近期申请了一个微信公众号:平凡程式人生。有兴趣的朋友可以关注,那里将会涉及更多机器学习、OpenCL+OpenCV以及图像处理方面的文章。
2.2 简单实例
为了验证前面实现的K近邻算法正确性,先设计一个简单实例。
该实例中,在code中调用函数createDataSet()创建了样本数据group及分类 labels,具体实现如下:
def createDataSet():
group = array([[1.1, 1.1], [1.0, 1.2], [0.2, 0.4], [0.9, 0.1]]) #创建4x2的数组作为训练样本
labels = ['A', 'A', 'B', 'B'] #4个训练样本所属类的标记
return group, labels
然后调用模块kNN的函数kNN.knnClassify()对应测试样本[0.2, 0]进行分类测试,其分类结果为B,符合预期。具体实现如下:
group, labels = createDataSet() #创建训练样本和对应标记
realIndex = kNN.knnClassify([0.2, 0], group, labels, 3) #检测测试样本[0.2, 0]属于哪个类
print realIndex
为了清楚地看到训练样本各个类别的分类情况,可以调用matplotlib绘制了样本数据的散列图。具体实现如下:
#Matplotlib 里的常用类的包含关系为 Figure -> Axes -> (Line2D, Text, etc.)
#一个Figure对象可以包含多个子图(Axes),在matplotlib中用Axes对象表示一个绘图区域,可以理解为子图。
fig = plt.figure() #创建图表fig
ax = fig.add_subplot(1, 1, 1) #在图表fig中创建一个子图ax
#绘制散列图,前两个参数表示x轴和y轴所要显示