K近邻算法的KD树实现

#K近邻算法的KD树实现
#lichunyu-2020.6.3

import pandas as pd
import numpy as np
import math
import matplotlib.pyplot as plt

class Node:
    def __init__(self):
        self.left = None
        self.right = None
        self.value = [] #vector

class Neighbour:
    def __init__(self, k):
        self.k = k
        self.nk = [(None, float('inf'))] * self.k
        
    def getMaxDist(self):
        return max(self.nk, key=lambda elem: elem[1])[1]

    def update_max(self, item): #用 item 替换 nk 中距离最大的元素
        for i in range(self.k):
            if self.nk[i][1] == self.getMaxDist():
                self.nk[i] = item
                break

    def show(self):
        self.nk.sort(key=lambda elem: elem[1])
        for i in range(self.k):
            # print(self.nk[i][0].value, self.nk[i][1])
            plt.plot(self.nk[i][0].value[0], self.nk[i][0].value[1], 'rx', c='g', label='nk')

class KDTree:
    def __init__(self, data, neighbour, p = 2):
        self.root = None
        self.dimension = len(data[0]) - 1 #x[0] = [x1, x2, y]
        self.root = self.construct(data, 0)
        self.p = p  #距离变量
        self.neighbour = neighbour # k个邻域

    def construct(self, data, cur_d): # cur_d -> 当前坐标维度
        if(len(data) == 0):
            return None

        data = data[data[:, cur_d].argsort()]    # 按照当前维度的坐标排序
        mid = len(data) // 2
        node = Node()
        node.value = data[mid]
        next_d = (cur_d + 1) % self.dimension
        node.left = self.construct(data[0 : mid, :], next_d)
        node.right = self.construct(data[mid + 1 :, :], next_d)

        return node

    def search(self, node, pos, cur_d = 0): # kd-tree 查找最近邻
        if pos[cur_d] <= node.value[cur_d]:
            nearer_node = node.left
            further_node = node.right
        else:
            nearer_node = node.right
            further_node = node.left

        next_d = (cur_d + 1) % self.dimension
        if nearer_node:
            self.search(nearer_node, pos, next_d)
        
        #当前 node 与 pos 的距离 ---> 是否更近
        distance = self._Lp(node.value[:-1], pos, self.p)
        if distance < self.neighbour.getMaxDist():
            self.neighbour.update_max((node, distance))

        #另一个子节点的区域是否与超球体相交 $$超球体以neighbour中最大距离为半径
        if further_node and (further_node.value[cur_d] - pos[cur_d] < self.neighbour.getMaxDist()): #如果相交
            self.search(further_node, pos, next_d)  #在另一个结点的区域内找更近的

    def _Lp(self, x1, x2, p):
        sum = 0
        for i in range(len(x1)):
            sum += math.pow(abs(x1[i] - x2[i]), p)
        return math.pow(sum, 1 / p)

class KNN:
    def __init__(self, data, k = 1, p = 2):
        self.neighbour = Neighbour(k)
        self.kdTree = KDTree(data, self.neighbour, p)
        
    def predict(self, pos):
        self.kdTree.search(self.kdTree.root, pos)
        return self.judge(self.kdTree.neighbour.nk)

    def judge(self, nk):
        dict_class_times = {}
        for each in nk: #统计k近邻 中每个 class 出现次数
            belong = each[0].value[-1]
            if belong in dict_class_times:    #y[index] --> class
                dict_class_times[belong] += 1 
            else:
                dict_class_times[belong] = 1
        
        return max(dict_class_times, key=lambda elem: dict_class_times[elem])


def test():
    #data
    from sklearn.datasets import load_iris
    iris = load_iris()
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    df['label'] = iris.target
    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
    data = np.array(df.iloc[:100, [0, 1, -1]])
    
    # data = np.array([[2,3,0],[5,4,0],[9,6,0],[4,7,0],[8,1,0],[7,2,0]])
    # plt.scatter(data[:, 0], data[:, 1], c='y', label='1')
    
    knn = KNN(data, k = 10, p = 2)
    pos = [5.1, 2.8]
    belong = knn.predict(pos)
    print(pos, "belongs to ", belong)

    knn.kdTree.neighbour.show()
    plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], c='b', label='0')
    plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], c='y', label='1')
    plt.plot(pos[0], pos[1], 'b*', label='test_point')
    plt.xlabel('sepal length')
    plt.ylabel('sepal width')
    plt.show()

if __name__ == "__main__":
    test()
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值