话不多说直接上代码:
import numpy as np
import operator
import matplotlib.pyplot as plt
group = np.array([[1.0, 1.1],
[1.0, 1.0],
[0, 0],
[0, 0.1]
])
labels = ['A', 'A', 'B', 'B'] #四个点分布对于各个标签
print(group)
dataSetSize = group.shape[0] #shape不带参数表示读取数组或者矩阵的行数和列数shape[0]读取矩阵或 者数组的第一维长度其中[0,0]是测试点
multitestData = np.tile([0,0], (dataSetSize, 1)) #tile()按四行一列复制group。得到4行2列的矩阵,以便与后面4个点求距离
diffMat = multitestData - group
sqdiffMat = diffMat**2 #矩阵中每个数均会平方
sqdistance = sqdiffMat.sum(axis=1)
print(sqdistance)
distance = sqdistance**0.5 #算出了[0,0]测试点到样本点的四个欧式距离放在了列表中
print(dataSetSize)
print(np.array(multitestData))
multitestData.shape
print(sqdiffMat)
sortedDistIndex = distance.argsort() #对求得的distance距离从小到大排序并得到其索引号
print(sortedDistIndex )
classCount = {} #生成一个空字典
for i in range(3):
voteIlabel = labels[sortedDistIndex[i]]
print(voteIlabel)
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 #字典get()方法通过“键值”放回“值” 最终是便签的计数
print(classCount[voteIlabel])
#通过键得到值
print(classCount.items())
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)#sorted()返回一个新的列表
#items()以列表返回可遍历的(键, 值) 元组数组
#operator.itemgetter(1)使用原组的第二个数排序(此时为升序),operator.itemgetter(1,0)使用元组第二个元素排序后再使用元组第一个元素排序
print(sortedClassCount)
print(sortedClassCount[0][0]) #输出结果[0,0]最终被分为B类
plt.plot(group[:, 0], group[:, 1], 'o')
plt.xlim(-0.1, 1.2) #x轴取值范围[-0.1,1.2]
plt.ylim(-0.1, 1.2)
plt.show()
print('classify result is:',sortedClassCount[0][0])
顺便详解下sorted函数和operator.itemgetter函数的使用
完整代码:
import numpy as np
import operator
import matplotlib.pyplot as plt
def createdataset():
"""该函数用于产生kNN实验用列,返回样本数据集(numpy数组)和样本类型集(列表)
Keyword arguments:
None
"""
group = np.array([[1.0, 1.1],
[1.0, 1.0],
[0, 0],
[0, 0.1]
])
labels = ['A', 'A', 'B', 'B']
return group, labels
def classify(testData, dataSet, labels, k):
"""应用KNN方法对测试点进行分类,返回一个结果类型
Keyword argument:
testData: 待测试点,格式为数组
dataSet: 训练样本集合,格式为矩阵
labels: 训练样本类型集合,格式为数组
k: 近邻点数
"""
dataSetSize = dataSet.shape[0]
multitestData = np.tile(testData, (dataSetSize, 1))
diffMat = multitestData - dataSet
sqdiffMat = diffMat**2
sqdistance = sqdiffMat.sum(axis=1)
print(sqdistance)
distance = sqdistance**0.5
sortedDistIndex = distance.argsort()
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndex[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
print(sortedClassCount)
return sortedClassCount[0][0]
group, labels = createdataset()
print('group is:')
print(group)
print('labels is:')
print(labels)
plt.plot(group[:, 0], group[:, 1], 'o')
plt.xlim(-0.1, 1.2)
plt.ylim(-0.1, 1.2)
plt.show()
test = [0, 0]
print('test is:')
print(test)
print('classify result is:')
print(classify(test, group, labels, 3))
结果: