KNN的基本思路是:首先给出一组数据,并给出每个数据所属的类别;然后对于一个新的数据,计算它与已有的每一个数据之间的距离;接着选定一个K值,选取与这个新数据距离最近的前K个点,将这K个点中出现次数最多的类别作为这个新数据的类别。具体代码如下:
import numpy as np
import matplotlib.pyplot as plt
#创建数据,让x,y都服从均值为3,标准差为5的正态分布
def createdata(n):
    """Generate ``n`` random 2-D samples with random binary labels.

    Both coordinates are drawn independently from a normal distribution
    with mean 3 and standard deviation 5 (the second argument of
    ``np.random.normal`` is the standard deviation, not the variance).

    Args:
        n: Number of samples to generate.

    Returns:
        A tuple ``(x, y, lable)`` of three length-``n`` arrays: the x
        coordinates, the y coordinates and the labels (each 0 or 1).
    """
    x = np.random.normal(3, 5, n)
    y = np.random.normal(3, 5, n)
    # randint's upper bound is exclusive, so labels are 0 or 1 (two classes).
    lable = np.random.randint(0, 2, n)
    return x, y, lable
#将生成的数据和新的数据显示出来
def showdata(x, y, lable, data, Lable):
    """Plot the labelled samples and blink the newly classified point.

    Samples labelled 0 are drawn in red, all others in blue. The new point
    alternates every 0.5 s between its predicted class colour and green so
    that it stands out. The function returns once the figure window has
    been closed.

    Args:
        x: Sequence of sample x coordinates.
        y: Sequence of sample y coordinates.
        lable: Sequence of sample labels, same length as ``x`` and ``y``.
        data: The new point as a two-element sequence ``(x, y)``.
        Lable: Predicted label of the new point (0 -> red, otherwise blue).
    """
    x = np.asarray(x)
    y = np.asarray(y)
    is_zero = np.asarray(lable) == 0
    # One scatter call per class instead of one call per point.
    plt.scatter(x[is_zero], y[is_zero], color='r')
    plt.scatter(x[~is_zero], y[~is_zero], color='b')

    color = 'r' if Lable == 0 else 'b'
    fig = plt.gcf()
    # A single artist is recoloured on every tick; the previous version
    # added two new scatter artists per iteration and never terminated.
    marker = plt.scatter(data[0], data[1], color=color)
    highlighted = False
    # Stop blinking as soon as the user closes the window; plt.pause also
    # takes care of displaying the figure, so no trailing plt.show() is needed.
    while plt.fignum_exists(fig.number):
        plt.pause(0.5)
        highlighted = not highlighted
        marker.set_color('g' if highlighted else color)
#计算新数据与每个数据的距离并根据距离来判断新数据属于哪一个类别
def classification(x, y, lable, data, k):
    """Classify a 2-D point by majority vote among its k nearest samples.

    Distances are Euclidean. When several samples are equally far away,
    the one with the lower index is preferred.

    Args:
        x: Sequence of sample x coordinates.
        y: Sequence of sample y coordinates.
        lable: Sequence of sample labels; 0 is one class, any other value
            counts as class 1.
        data: The point to classify as a two-element sequence ``(x, y)``.
        k: Number of nearest neighbours that vote. Values larger than the
            number of samples are clamped to it.

    Returns:
        0 if strictly more of the k neighbours are labelled 0, otherwise 1
        (so a tie is resolved in favour of class 1).

    Raises:
        ValueError: If ``k`` is smaller than 1 or there are no samples.
    """
    if k < 1:
        raise ValueError('k must be at least 1')
    if len(lable) == 0:
        raise ValueError('at least one labelled sample is required')

    X, Y = data
    dx = np.asarray(x, dtype=float) - X
    dy = np.asarray(y, dtype=float) - Y
    distances = np.sqrt(dx ** 2 + dy ** 2)

    # argsort yields each sample index exactly once, so equidistant
    # neighbours are no longer collapsed onto the first matching index.
    nearest = np.argsort(distances, kind='stable')[:k]
    neighbour_labels = np.asarray(lable)[nearest]

    # Both tallies start from zero; the earlier version started class 1 at
    # one, which gave it a free extra vote.
    class_0 = int(np.count_nonzero(neighbour_labels == 0))
    class_1 = len(neighbour_labels) - class_0
    return 0 if class_0 > class_1 else 1
# Build a random training set of 100 labelled points.
x, y, lable = createdata(100)
# The query point can come from a random draw (disabled alternative below)
# or from the keyboard.
# data = np.random.normal(3,5,2)
X = input('请输入x:')
Y = input('请输入y:')
data = [float(X), float(Y)]
# Vote among the 5 nearest neighbours, then visualise the result.
Lable = classification(x, y, lable, np.array(data), 5)
showdata(x, y, lable, data, Lable)
通用版本(更好):
import numpy as np
import matplotlib.pyplot as plt
import operator
def createDataSet(n=100):
    """Generate a random labelled 2-D data set.

    Coordinates are drawn from a normal distribution with mean 3 and
    standard deviation 5; each sample is labelled 'A' or 'B' at random.

    Args:
        n: Number of samples to generate (defaults to 100).

    Returns:
        A tuple ``(group, labels)`` where ``group`` is an ``(n, 2)`` float
        array of points and ``labels`` is a length-``n`` array of the
        strings 'A' and 'B'.
    """
    group = np.random.normal(3, 5, 2 * n).reshape(n, 2)
    # A random 1 maps to 'A' and a random 0 to 'B'.
    labels = np.where(np.random.randint(0, 2, n), 'A', 'B')
    return group, labels
def classify(dataSet, labels, inX, k):
    """Return the majority label among the k samples nearest to ``inX``.

    Distances are Euclidean. If several labels receive the same number of
    votes, the one whose first vote came from the nearest sample wins.

    Args:
        dataSet: Array of samples, one row per sample.
        labels: Sequence of labels, one per row of ``dataSet``.
        inX: The point to classify, with one value per column of ``dataSet``.
        k: Number of nearest neighbours that vote. Values larger than the
            number of samples are clamped to it.

    Returns:
        The winning label, taken unchanged from ``labels``.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError('k must be at least 1')

    offsets = np.asarray(dataSet) - inX
    distances = np.sum(offsets ** 2, axis=1) ** 0.5
    # A stable sort keeps equidistant samples in index order, and slicing
    # clamps k to the number of available samples.
    nearest = distances.argsort(kind='stable')[:k]

    classCount = {}  # label -> number of votes, in nearest-first order
    for idx in nearest:
        voteclass = labels[idx]
        classCount[voteclass] = classCount.get(voteclass, 0) + 1
    # max() returns the first maximal key, i.e. the earliest-seen label
    # among those with the highest count.
    return max(classCount, key=classCount.get)
def show_data(dataSet, labels):
    """Scatter-plot the samples coloured by label and show the figure.

    Points labelled 'A' are red and points labelled 'B' are blue; samples
    with any other label are not drawn. The last sample is drawn with an
    'x' marker instead of the default marker.

    Args:
        dataSet: Indexable collection of points; element ``i`` provides
            the x and y coordinates as ``dataSet[i][0]`` and
            ``dataSet[i][1]``.
        labels: Sequence of labels, one per point.
    """
    colours = {'A': 'r', 'B': 'b'}
    last = len(labels) - 1
    for i, label in enumerate(labels):
        colour = colours.get(label)
        if colour is None:
            # Unknown labels were silently skipped before; keep doing so.
            continue
        if i == last:
            plt.scatter(dataSet[i][0], dataSet[i][1], c=colour, marker='x')
        else:
            plt.scatter(dataSet[i][0], dataSet[i][1], c=colour)
    plt.show()
if __name__ == '__main__':
    # Build the random training set.
    dataSet, labels = createDataSet()
    # Draw a random query point (mean 5, standard deviation 3 per axis).
    inX = np.random.normal(5, 3, 2)
    Class = classify(dataSet, labels, inX, 3)
    print('新数据属于{}类别'.format(Class))
    # Append the query point and its predicted label as the final sample;
    # show_data draws the last sample with an 'x' marker.
    dataSet = np.vstack((dataSet, inX))
    labels = np.append(labels, Class)
    show_data(dataSet, labels)