# -*- coding: utf-8 -*-
"""
Created on Sun Oct 21 19:47:08 2018
@author: 国涛
"""
import operator
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
def createDataSet():
    """Build the tiny two-class toy dataset used by the demo.

    Returns:
        tuple: ``(group, labels)`` where ``group`` is a 4x2 NumPy array of
        sample coordinates (one sample per row) and ``labels`` is a list
        holding the class label of each row, in the same order.
    """
    points = [
        [1.0, 1.1],
        [1.0, 1.0],
        [0, 0],
        [0, 0.1],
    ]
    # Two samples of class 'A' followed by two of class 'B'.
    classes = ['A'] * 2 + ['B'] * 2
    return np.array(points), classes
def createMap(group, labels):
    """Scatter-plot labelled 2-D points on the current matplotlib figure.

    Both axes are fixed to the range [-0.2, 1.2]. Points labelled 'A' are
    drawn as red circles and every other label as blue circles. Each point is
    annotated with its label, right-aligned just to the left of the point.

    Args:
        group: 2-D NumPy array, one point per row; column 0 is x, column 1 is y.
        labels: sequence with one label per row of ``group``.
    """
    # Tuple-unpacking the shape also insists that ``group`` is 2-D.
    n_points, _ = group.shape
    plt.xlim(-0.2, 1.2)
    plt.ylim(-0.2, 1.2)
    for idx in range(n_points):
        x, y = group[idx, 0], group[idx, 1]
        style = 'ro' if labels[idx] == 'A' else 'bo'
        plt.plot(x, y, style)
        plt.text(x - 0.05, y, str(labels[idx]), ha='right', wrap=True)
def classify0(inX, dataSet, labels, k):
    """Classify one point by a majority vote of its k nearest neighbours.

    Args:
        inX: coordinates of the query point (a sequence or 1-D array with one
            entry per column of ``dataSet``).
        dataSet: training samples, one per row (a 2-D NumPy array or anything
            ``np.asarray`` turns into one).
        labels: sequence of class labels, one per row of ``dataSet``.
        k: number of nearest neighbours that vote. Values larger than the
            number of samples are clamped, so every sample votes.

    Returns:
        The label occurring most often among the ``k`` nearest samples
        (Euclidean distance). On a tie, the label whose member is nearest to
        ``inX`` wins.

    Raises:
        ValueError: if ``k`` is less than 1 or ``dataSet`` has no samples.
    """
    dataSet = np.asarray(dataSet)
    # Number of training samples
    dataSetSize = dataSet.shape[0]
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    if dataSetSize == 0:
        raise ValueError("dataSet must contain at least one sample")
    # Broadcasting subtracts inX from every row; no need to tile a copy.
    diffMat = np.asarray(inX) - dataSet
    # Euclidean distance from inX to each sample
    distances = np.sqrt((diffMat ** 2).sum(axis=1))
    # A stable sort keeps equidistant samples in their original row order, so
    # the chosen neighbours are deterministic. Clamp k to the dataset size.
    nearest = distances.argsort(kind='stable')[:min(k, dataSetSize)]
    # Counter remembers first-seen (i.e. nearest-first) order, and
    # most_common returns equal counts in that order, so ties favour the
    # label that appears closest to inX.
    votes = Counter(labels[i] for i in nearest)
    return votes.most_common(1)[0][0]
if __name__ == '__main__':
    # Build the toy training set
    group, labels = createDataSet()
    # Plot the training points
    createMap(group, labels)
    # The query point to classify
    test = [1.0, 0.5]
    # Plot the query point (green '>' marker)
    plt.plot(test[0], test[1], 'g>')
    # Classify the query point using its 3 nearest neighbours
    testClass = classify0(test, group, labels, 3)
    plt.text(test[0] - 0.05, test[1] + 0.05, 'testPoint')
    plt.text(test[0], test[1] - 0.05, testClass)
    print(testClass)
    # Without an explicit show() the figure is never displayed when the
    # script runs outside an interactive/inline matplotlib backend.
    plt.show()
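# Simple k-nearest neighbours (简单k-近邻).
# Residue copied from the source web page, kept as a comment so the file parses:
# "最新推荐文章于 2023-10-16 22:54:06 发布"
# ("latest recommended article published 2023-10-16 22:54:06").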