#K近邻算法的KD树实现
#lichunyu-2020.6.3
import pandas as pd
import numpy as np
import math
import matplotlib.pyplot as plt
class Node:
    """A single kd-tree node: one stored data vector plus two subtrees."""

    def __init__(self):
        self.left = None   # subtree with coordinates <= the split value
        self.right = None  # subtree with coordinates > the split value
        self.value = []    # data row: [x1, ..., xd, label]
class Neighbour:
    """Fixed-size pool of the k nearest neighbours found so far.

    Each entry of ``nk`` is a ``(node, distance)`` pair.  Slots start as
    ``(None, inf)`` sentinels so that any real point evicts them first.
    """

    def __init__(self, k):
        self.k = k
        # (node, distance) pairs; tuples are immutable, so list repetition is safe.
        self.nk = [(None, float('inf'))] * self.k

    def getMaxDist(self):
        """Return the largest stored distance (the current eviction radius)."""
        return max(self.nk, key=lambda elem: elem[1])[1]

    def update_max(self, item):
        """Replace the worst (largest-distance) entry with ``item``.

        Fix: the original re-evaluated getMaxDist() on every loop iteration,
        making each update O(k^2); the threshold is now computed once.
        """
        worst = self.getMaxDist()
        for i in range(self.k):
            if self.nk[i][1] == worst:
                self.nk[i] = item
                break

    def show(self):
        """Plot the stored neighbours as green crosses, sorted by distance."""
        self.nk.sort(key=lambda elem: elem[1])
        for node, dist in self.nk:
            # Fix: skip unfilled sentinel slots (node is None when fewer than
            # k real points were inserted) instead of crashing on None.value.
            if node is None:
                continue
            # 'x' marker; colour comes from c= (the original's 'rx' red was
            # already overridden by c='g').
            plt.plot(node.value[0], node.value[1], 'x', c='g', label='nk')
class KDTree:
    """kd-tree over rows of the form [x1, ..., xd, label].

    The last column of each row is the class label: it is excluded from the
    splitting dimensions and from the distance computation.
    """

    def __init__(self, data, neighbour, p=2):
        # data: numpy 2-D array, each row [x1, ..., xd, y]
        self.dimension = len(data[0]) - 1  # number of feature dimensions d
        self.p = p                         # order of the Minkowski (Lp) distance
        self.neighbour = neighbour         # pool collecting the k nearest neighbours
        self.root = self.construct(data, 0)

    def construct(self, data, cur_d):
        """Recursively build the subtree for ``data``, splitting on dimension ``cur_d``."""
        if len(data) == 0:
            return None
        data = data[data[:, cur_d].argsort()]  # sort rows by the split dimension
        mid = len(data) // 2                   # median row becomes this node's point
        node = Node()
        node.value = data[mid]
        next_d = (cur_d + 1) % self.dimension  # cycle through feature dims only
        node.left = self.construct(data[0:mid, :], next_d)
        node.right = self.construct(data[mid + 1:, :], next_d)
        return node

    def search(self, node, pos, cur_d=0):
        """Descend from ``node`` collecting the k nearest neighbours of ``pos``."""
        if pos[cur_d] <= node.value[cur_d]:
            nearer_node, further_node = node.left, node.right
        else:
            nearer_node, further_node = node.right, node.left
        next_d = (cur_d + 1) % self.dimension
        if nearer_node:
            self.search(nearer_node, pos, next_d)
        # Is the current node closer than the worst neighbour kept so far?
        distance = self._Lp(node.value[:-1], pos, self.p)
        if distance < self.neighbour.getMaxDist():
            self.neighbour.update_max((node, distance))
        # Backtracking test: does the far region intersect the hypersphere
        # centred at ``pos`` with radius = current worst distance?
        # Fix: the splitting hyperplane lies at THIS node's coordinate (the
        # original wrongly used the child's coordinate), and the offset must
        # be taken as an absolute value — the signed form both pruned valid
        # branches and explored useless ones, returning wrong neighbours.
        if further_node and abs(node.value[cur_d] - pos[cur_d]) < self.neighbour.getMaxDist():
            self.search(further_node, pos, next_d)

    def _Lp(self, x1, x2, p):
        """Minkowski distance of order ``p`` between equal-length vectors x1 and x2."""
        total = 0.0  # renamed from ``sum`` to avoid shadowing the builtin
        for a, b in zip(x1, x2):
            total += abs(a - b) ** p
        return total ** (1.0 / p)
class KNN:
    """k-nearest-neighbour classifier backed by a kd-tree."""

    def __init__(self, data, k=1, p=2):
        self.neighbour = Neighbour(k)
        self.kdTree = KDTree(data, self.neighbour, p)

    def predict(self, pos):
        """Classify ``pos`` by majority vote among its k nearest neighbours."""
        self.kdTree.search(self.kdTree.root, pos)
        return self.judge(self.kdTree.neighbour.nk)

    def judge(self, nk):
        """Return the most frequent class label in ``nk`` ((node, distance) pairs)."""
        votes = {}
        for node, _ in nk:  # tally each neighbour's label (last vector component)
            label = node.value[-1]
            votes[label] = votes.get(label, 0) + 1
        return max(votes, key=votes.get)
def test():
    """Demo: classify one point on the first two iris features and plot everything."""
    from sklearn.datasets import load_iris

    # Build a DataFrame of the iris data and rename the columns.
    iris = load_iris()
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    df['label'] = iris.target
    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
    # First 100 rows = classes 0 and 1; keep sepal length/width plus the label.
    data = np.array(df.iloc[:100, [0, 1, -1]])

    knn = KNN(data, k=10, p=2)
    pos = [5.1, 2.8]
    belong = knn.predict(pos)
    print(pos, "belongs to ", belong)

    # Visualise the two classes, the query point, and its k neighbours.
    knn.kdTree.neighbour.show()
    plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], c='b', label='0')
    plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], c='y', label='1')
    plt.plot(pos[0], pos[1], 'b*', label='test_point')
    plt.xlabel('sepal length')
    plt.ylabel('sepal width')
    plt.show()
# Run the interactive demo only when this file is executed as a script.
if __name__ == "__main__":
    test()
# K-nearest-neighbour KD-tree implementation
# (trailing page footer from the original blog post, dated 2021-09-27 16:13:23;
#  commented out so the file parses as valid Python)