KNN算法(基于KD-Tree)

输入数据可以自行修改(见 setData);设置要查询的点 point 和 k 值后,返回距离该点最近的 k 个点。

import heapq
import itertools

import numpy as np


class KD_Node:
    """A single node of a KD-tree: one stored point plus its two subtrees."""

    def __init__(self, point=None, split=None, leftNode=None, rightNode=None):
        """
        :param point: the data point stored at this node
        :param split: index of the dimension this node splits on
        :param leftNode: left child (KD_Node or None); in trees built by
            createKDTree it holds points ordered before ``point`` along ``split``
        :param rightNode: right child (KD_Node or None); in trees built by
            createKDTree it holds points ordered after ``point`` along ``split``
        """
        self.point = point
        self.split = split
        self.leftNode = leftNode
        self.rightNode = rightNode

    def __repr__(self):
        # Children are omitted so printing a large tree stays readable.
        return f"KD_Node(point={self.point!r}, split={self.split!r})"


def setData():
    """Return a freshly built copy of the small 2-D sample data set."""
    return [
        [2, 3],
        [5, 4],
        [9, 6],
        [4, 7],
        [8, 1],
        [7, 2],
    ]


def createKDTree(group):
    """
    Build a KD-tree from a collection of points.

    At every level the split dimension is the one whose coordinates have the
    largest variance (the lowest index wins ties, and dimension 0 is used when
    no variance is positive). The median point along that dimension becomes
    the node; the points before it form the left subtree and the points after
    it form the right subtree.

    :param group: sequence of points, each an indexable sequence of the same
        length. The caller's sequence is not modified.
    :return: the root KD_Node, or None when ``group`` is empty
    """
    length = len(group)
    if length == 0:
        return None
    dimension = len(group[0])
    # Dimension to split on
    split = 0
    # Largest variance seen so far
    max_var = 0
    for i in range(dimension):
        var = cal_variance([data[i] for data in group])
        if var > max_var:
            max_var = var
            split = i
    # sorted() rather than list.sort(): don't reorder the caller's list, and
    # accept any sequence (tuples, etc.), not just lists.
    ordered = sorted(group, key=lambda x: x[split])
    mid = length // 2
    root = KD_Node(ordered[mid], split)
    root.leftNode = createKDTree(ordered[:mid])
    # The median is stored in this node, so the right half starts at mid + 1.
    root.rightNode = createKDTree(ordered[mid + 1:])
    return root


def cal_variance(dataList):
    """
    Compute the population variance (ddof=0) of the values in one dimension.

    The one-pass identity D(X) = E(X^2) - [E(X)]^2 is avoided on purpose:
    - It loses all precision (and can even go negative) when the values are
      large compared with their spread.
    - Squaring an int64 array overflows silently for values above ~3e9.

    Instead, the values are converted to float and numpy's ``np.var`` computes
    the mean of the squared deviations from the mean.

    :param dataList: numeric values of a single dimension across all points
    :return: the variance as a float (NaN, with a numpy RuntimeWarning, for
        empty input)
    """
    return float(np.var(np.asarray(dataList, dtype=float)))


def findKN(point, root, k):
    """
    Return the k points in a KD-tree that are nearest to ``point``.

    The search is a depth-first walk:
    - First descend into the child on the query's side of each splitting
      plane.
    - Then visit the other child only when fewer than k candidates have been
      found, or when the splitting plane is closer than the current k-th best
      distance. Only then can that subtree still hold a closer point.

    Earlier versions of this function had three bugs, all fixed here:
    - the root was counted twice;
    - the far side was pruned before k candidates had been collected;
    - far subtrees were only partially searched.

    :param point: query point (indexable, same dimensionality as the tree)
    :param root: root KD_Node of the tree, or None for an empty tree
    :param k: number of neighbours to return
    :return: up to k points, ordered by increasing Euclidean distance (all
        points if the tree has fewer than k); None when k < 1
    """
    if k < 1:
        return None

    # Max-heap of the best candidates so far, stored as
    # (-squared_distance, visit_order, point). visit_order breaks distance ties
    # so the points themselves are never compared (e.g. numpy arrays).
    best = []
    visit_order = itertools.count()

    def search(node):
        if node is None:
            return
        # Squared distance orders points the same as Euclidean distance.
        d2 = sum((a - b) ** 2 for a, b in zip(point, node.point))
        entry = (-d2, next(visit_order), node.point)
        if len(best) < k:
            heapq.heappush(best, entry)
        elif d2 < -best[0][0]:
            # Closer than the current worst of the k kept: replace it.
            heapq.heapreplace(best, entry)

        axis = node.split
        diff = point[axis] - node.point[axis]
        if diff < 0:
            near, far = node.leftNode, node.rightNode
        else:
            near, far = node.rightNode, node.leftNode
        search(near)
        # The far half-space can only hold a better point if the splitting
        # plane is within the current k-th best distance (or we still need
        # more candidates).
        if len(best) < k or diff * diff < -best[0][0]:
            search(far)

    search(root)
    best.sort(key=lambda item: (-item[0], item[1]))
    return [item[2] for item in best]


def sortP(group, point):
    """
    Sort ``group`` in place by Euclidean distance to ``point``, nearest first.

    :param group: list of points; it is reordered in place
    :param point: reference point
    :return: the same list object, for convenience
    """
    # list.sort computes each distance once and runs in O(n log n). The
    # previous exchange sort was O(n^2) and recomputed both distances on every
    # comparison. list.sort is also stable, so points at equal distance keep
    # their relative order.
    group.sort(key=lambda p: cal_dist(p, point))
    return group


def cal_dist(point1, point2):
    """
    Return the Euclidean distance between two points.

    Coordinates are paired by position over ``point1``:
    - A shorter ``point2`` raises IndexError.
    - Extra trailing coordinates in ``point2`` are ignored.
    """
    total = 0
    for idx, coord in enumerate(point1):
        diff = coord - point2[idx]
        total += diff ** 2
    return total ** 0.5

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值