KD树算法

与传统的KNN算法比较我感觉慢很多,我的姿势是不是不对

kd树

import numpy as np
from numpy import *

class KDNode():
    """
    KDNode
    point:该节点的样本点
    split:用于判断分割的维度(属性)
    left:左节点
    right:右节点
    """
    def __init__(self, point=None, split=None, left=None, right=None):
        self.point = point
        self.split = split
        self.left = left
        self.right = right

class KDTree():
    """
    KD树
    KDNode:kd-tree的节点
    dimensions:数据的纬度
    right:节点的右子节点
    left:节点的左子节点
    curr_axis:当前需要切分的纬度
    next_axis:下一次需要切分的纬度
    """
    def __init__(self,data=None):
        """
        采用递归的方式创建树
        """
        def createNode(split=None, data_set=None):
            """
            递归创建节点
            input:(1)split:分割的维度(2)data_set:需要分割的样本点集合
            output:KDnode
            """
            if len(data_set) == 0:
                return None # 数据集为空,作为递归的停止条件
            # 按照split对data_set进行排序,找到split维度中的中位数
            data_set = list(data_set)
            data_set.sort(key=lambda x: x[split]) # 按照split维的数据大小排序
            data_set = np.array(data_set)
            median = len(data_set) // 2 # 不用python自带的median函数,我返回的是median的位置所在的索引
            # data_set[median]就是这个节点的样本点
            # split是这个节点的分割维度
            # data_set[:median]样本节点左半部分 data_set[median-1:]
            print("------------",median)
            print('data_set[:median]',data_set[:median])
            print('data_set[median+1:]',data_set[median+1:])
            return KDNode(data_set[median],split,createNode(maxVar(data_set[:median]),data_set[:median]),createNode(maxVar(data_set[median+1:]),data_set[median+1:]))

        def maxVar(data_set=None):
            """
            计算样本集的最大方差维度
            input:data_set样本集
            output:split:最大方差的维度,作为createNode的输入值
            """
            if len(data_set)==0:
                return 0
            print("======",len(data_set))
            data_mean = np.mean(data_set,0) # 按照列求平均值
            print(data_mean)
            mean_differ = data_set - data_mean # 求均值差
            data_var = np.sum(mean_differ ** 2, axis=0)/len(data_set) # 求方差,差反映数据的分散特征,方差的数值越大,那么数据的分散程度越大
            re = np.where(data_var == np.max(data_var)) # 寻找方差最大的位置
            print("re:",re)
            return re[0][0] # 方差最大的维数
        # print(data)
        self.root = createNode(maxVar(data),data)

def computeDist(pt1,pt2):
    """
    计算两个点之间的距离
    点的类型是N维的
    """
    sum = 0.0
    for i in range(len(pt1)):
        sum = sum + (pt1[i] - pt2[i]) ** 2
    return np.math.sqrt(sum)

def preOrder(root):
    """
    前序遍历KD树
    """
    print(root.point)
    if root.left:
        preOrder(root.left)
    if root.right:
        preOrder(root.right)

def updateNN(min_dist_array=None, tmp_dist=0.0, NN=None, tmp_point=None, k=1):
    """
    更新近邻点和对应的最小距离的集合
    min_dist_array为最小距离的集合
    NN为邻近点的集合
    tmp_dist和tmp_point分别是需要更新到min_dist_array,NN里的近邻点和距离
    """
    # 如果距离更小就更新min_dist_array
    if tmp_dist <= np.min(min_dist_array):
        # 删除最大距离和对应的节点
        for i in range(k-1,0,-1):
            min_dist_array[i] = min_dist_array[i-1]
            NN[i] = NN[i-1]

        min_dist_array[0] = tmp_dist
        NN[0] = tmp_point
        return NN,min_dist_array
    for i in range(k) :
        if (min_dist_array[i] <= tmp_dist) and (min_dist_array[i+1] >= tmp_dist) :
            #tmp_dist在min_dist_array的第i位和第i+1位之间,则插入到i和i+1之间,并把最后一位给剔除掉
            for j in range(k-1,i,-1) : #range反向取值
                min_dist_array[j] = min_dist_array[j-1]
                NN[j] = NN[j-1]
            min_dist_array[i+1] = tmp_dist
            NN[i+1] = tmp_point
            break
    return NN,min_dist_array

def searchKDTree(KDTree=None, target_point=None, k=1):
    """
    搜索KD树
    input:KDTree:kd树;target_point:目标点;k:距离目标点最近的k个点的k值
    output:k_arrayList,距离目标点最近的k个点的集合数组
    """
    if k==0 : return None

    tempNode = KDTree.root # 从更节点出发
    NN = [tempNode.point] * k #定义最邻近点集合,k个元素,按照距离远近,由近到远。初始化为k个根节点
    min_dist_array = [float("inf")] * k#定义近邻点与目标点距离的集合.初始化为无穷大
    nodeList = []

    def buildSearchPath(tempNode=None, nodeList=None,min_dist_array=None,NN=None,target_point=None):
        """
        此方法是用来建立以tempNode为根节点,以下所有节点的查找路径,并将它们存放到nodeList中
        nodeList为一系列节点的顺序组合,按此先后顺序搜索最邻近点
        tempNode为"根节点",即以它为根节点,查找它以下所有的节点(空间)
        """
        while tempNode:
            nodeList.append(tempNode)
            split = tempNode.split
            point = tempNode.point
            tmp_dist = computeDist(point,target_point)
            if tmp_dist < np.max(min_dist_array):
                NN,min_dist_array = updateNN(min_dist_array,tmp_dist,NN,point,k)# 更新最小距离和最近邻近点
            if target_point[split] <= point[split]:#如果目标点当前维的值小于等于切分点的当前维坐标值,移动到左节点
                tempNode = tempNode.left
            else:
                tempNode.right
        return NN,min_dist_array


    # 建立查找路径
    NN,min_dist_array = buildSearchPath(tempNode,nodeList,min_dist_array, NN, target_point)
    # 回溯查找
    while nodeList:
        back_node = nodeList.pop()
        split = back_node.split
        point = back_node.point
        #判断是否需要进入父节点搜素
        #如果当前纬度,目标点减实例点大于最小距离,就没必要进入父节点搜素了
        #因为目标点到切割超平面的距离很大,那邻近点肯定不在那个切割的空间里,即没必要进入那个空间搜素了
        if not abs(target_point[split] - point[split]) >= np.max(min_dist_array):
            if target_point[split] <= point[split]: # 在右侧
                tempNode = back_node.right
            else:
                tempNode = back_node.left # 在左侧
            if tempNode:
                NN,min_dist_array = buildSearchPath(tempNode,nodeList,min_dist_array, NN, target_point)
    return NN,min_dist_array

def classify0(inX, dataSet, labels, k):
    '''
    k近邻算法的分类器
    input:
    inX:目标点
    dataSet:训练点集合
    labels:训练点对应的标签
    k:k值
    这个方法的目的:已知训练点dataSet和对应的标签labels,确定目标点inX对应的labels
    ''' 
    kd = KDTree(dataSet)#构建dataSet的kd树
    NN,min_dist_array = searchKDTree(kd, inX, k)#搜索kd树,返回最近的k个点的集合NN,和对应的距离min_dist_array
    dataSet = dataSet.tolist()
    voteIlabels = []
    #多数投票法则确定inX的标签,为防止边界处分类不准的情况,以距离的倒数为权重,即距离越近,权重越大,越该认为inX是属于该类
    for i in range(k) :
        #找到每个近邻点对应的标签
        nni = list(NN[i])
        voteIlabels.append(labels[dataSet.index(nni)])

#     #开始记数,加权重的方法
#     uniques = np.unique(voteIlabels)
#     counts = [0.0] * len(uniques)
#     for i in range(len(voteIlabels)) :
#         for j in range(len(uniques)) :
#             if voteIlabels[i] == uniques[j] :
#                 counts[j] = counts[j] + uniques[j] / min_dist_array[i] #权重为距离的倒数
#                 break
    #开始记数,不加权重的方法
    uniques, counts = np.unique(voteIlabels, return_counts=True)
    return uniques[np.argmax(counts)]

# 处理文件数据
def file2matrix(filename):
    fr = open(filename) # 打开文件
    arrayOlines = fr.readlines() #读取文件
    numbersOfLines = len(arrayOlines) # 文件有多少行
    returnMat = zeros((numbersOfLines,3)) # 创建0矩阵
    classLabelVector = [] # 标签集合
    index = 0
    for line in arrayOlines:
        line = line.strip()#移除字符串头尾的空格
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3] # 取前三个数据然后给切片赋值
        classLabelVector.append(int(listFromLine[-1])) # 最后一个是标签
        index += 1
    return returnMat,classLabelVector

# 归一化特征值
def autoNorm(dataSet):
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    m = dataSet.shape[0]
    normDataSet = dataSet - tile(minVals,(m,1))
    normDataSet = normDataSet/tile(ranges,(m,1))
    return normDataSet,ranges,minVals


def datingClassTest():
    hoRatio = 0.1 # 测试样本的比例
    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt') # 载入数据
    normMat,ranges,minVals = autoNorm(datingDataMat) # 归一化处理
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio) # 获取测试样本
    errorCount = 0.0
    print(type(datingDataMat))
    print(type(datingLabels))
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if (classifierResult != datingLabels[i]): errorCount += 1.0
    print("the total error rate is: %f" % (errorCount/float(numTestVecs)))
    print(errorCount)

if __name__ == "__main__":
    # test()
    # test2()
    datingClassTest()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值