下图为 KNN 算法的伪代码,截自《机器学习实战》第 19 页:
import numpy as np
import operator
from os import listdir
def CreatDataSet():
    """Build the four-point toy training set used by the KNN demo.

    :returns: ``(group, labels)`` where ``group`` is a 4x2 float array with
        one sample per row and ``labels`` is a list holding the class label
        (``'A'`` or ``'B'``) of the corresponding row.
    """
    group = np.array([
        [1.0, 1.1],
        [1.0, 1.0],
        [0.0, 0.0],
        [0.0, 0.1],
    ])
    labels = ['A', 'A', 'B', 'B']
    return group, labels
def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its k nearest training samples.

    Nearness is the Euclidean distance between ``inX`` and each sample.

    :param inX: point to classify; array-like of floats that broadcasts
        against a single sample of ``dataSet``.
    :param dataSet: training samples, one sample per row (array-like).
    :param labels: class label of every training sample, indexed in the same
        order as the rows of ``dataSet``.
    :param k: number of nearest samples that take part in the vote. A value
        larger than the number of samples simply uses every sample.
    :returns: the label occurring most often among the k nearest samples.
        When several labels share the top count, the one met first while
        walking the neighbours from nearest to farthest wins. ``'-'`` is
        returned when there is nothing to vote on (empty ``dataSet`` or
        ``k == 0``).
    :raises ValueError: if ``k`` is negative.
    """
    if k < 0:
        # A negative k would otherwise be read by the slice below as
        # "drop the farthest |k| samples", which is never what is meant.
        raise ValueError("k must be non-negative, got %r" % (k,))

    samples = np.asarray(dataSet, dtype=float)
    sample_count = len(samples)
    if k == 0 or sample_count == 0:
        return '-'

    # Euclidean distance from inX to every sample in one broadcast
    # expression; the reshape sums over every axis except the sample axis.
    diffs = samples - np.asarray(inX, dtype=float)
    dist = np.sqrt(np.square(diffs).reshape(sample_count, -1).sum(axis=1))

    # Stable sort so equidistant samples keep their original index order.
    nearest = np.argsort(dist, kind="stable")[:k]

    # Count how often each label appears among the nearest samples.
    votes = {}
    for idx in nearest:
        label = labels[idx]
        votes[label] = votes.get(label, 0) + 1

    # max() returns the first key with the highest count, i.e. the label
    # that entered the tally earliest among those tied for the lead.
    return max(votes, key=votes.get)
# Sample run: classify the origin against the toy data set using the
# k nearest neighbours and print the predicted label.
group, labels = CreatDataSet()
inX = np.array([0, 0])
k = 3
print(classify0(inX, group, labels, k))