DBSCAN

1、KD树的构造(ongoing)

#DBSCAN inspects abnormal sample
import numpy as np
from heapq import heappush, heappop, nsmallest, heappushpop
from scipy.spatial import KDTree 
import matplotlib.pyplot as plt

# Matplotlib styles for the tree plot: a sawtooth box for inner (decision)
# nodes, a rounded box for leaves, and the arrow drawn from parent to child.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
leafNode = {"boxstyle": "round4", "fc": "0.8"}
arrow_args = {"arrowstyle": "<-"}

def getNumLeafs(mytree):
    """Count the horizontal leaf slots beneath a tree node, for plot layout.

    Each of ``mytree.less`` / ``mytree.greater`` contributes one slot, unless
    its class is named ``innerNode``; in that case its own slots are counted
    recursively. Empty (``None``) children therefore also take one slot each.
    """
    return sum(
        getNumLeafs(child) if type(child).__name__ == 'innerNode' else 1
        for child in (mytree.less, mytree.greater)
    )

def getTreeDepth(mytree):
    """Return the number of levels below ``mytree`` used for plot layout.

    A child whose class is named ``innerNode`` adds one level plus its own
    depth. Any other child (leaf node or ``None``) adds exactly one level.
    """
    child_depths = [
        1 + getTreeDepth(child) if type(child).__name__ == 'innerNode' else 1
        for child in (mytree.less, mytree.greater)
    ]
    return max(child_depths)

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw ``nodeTxt`` in a box at ``centerPt`` with an arrow from ``parentPt``.

    Both points are in axes-fraction coordinates of ``createPlot.ax1``, and
    ``nodeType`` is the bbox style dict.
    """
    axis = createPlot.ax1
    axis.annotate(
        nodeTxt,
        xy=parentPt, xycoords='axes fraction',
        xytext=centerPt, textcoords='axes fraction',
        va="center", ha="center",
        bbox=nodeType, arrowprops=arrow_args,
    )
    
def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString``, rotated 30 degrees, halfway between two points on ``createPlot.ax1``."""
    cx, cy = cntrPt[0], cntrPt[1]
    px, py = parentPt[0], parentPt[1]
    createPlot.ax1.text(cx + (px - cx) / 2.0, cy + (py - cy) / 2.0, txtString,
                        va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw the subtree ``myTree`` below the point ``parentPt``.

    Layout state lives in attributes of this function (``plotTree.totalW``,
    ``totalD``, ``xOff``, ``yOff``); ``createPlot`` initialises them before the
    first call. Each node is labelled with its ``pivot_idx``.

    Children whose class is not named ``innerNode`` each occupy one horizontal
    slot; this includes leaf nodes and empty ``None`` children. Only actual
    nodes are drawn.

    ``nodeTxt`` is accepted for compatibility but is not used (mid-edge
    labels are not drawn).
    """
    numLeafs = getNumLeafs(myTree)  # width of this subtree, in leaf slots
    # Centre this node horizontally over the slots its subtree occupies.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotNode(myTree.pivot_idx, cntrPt, parentPt, decisionNode)
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # descend one level
    for child in (myTree.less, myTree.greater):
        if type(child).__name__ == 'innerNode':
            plotTree(child, cntrPt, child.split_dim)
        else:
            # Leaf or empty slot: advance one slot, but draw only real nodes.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            if child is not None:
                plotNode(child.pivot_idx, (plotTree.xOff, plotTree.yOff),
                         cntrPt, leafNode)
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # back up to this level

def createPlot(inTree):
    """Draw the tree rooted at ``inTree`` in matplotlib figure 1 and show it.

    The axes are stored as ``createPlot.ax1`` for ``plotNode`` / ``plotMidText``.
    The layout state used by ``plotTree`` is stored as attributes of
    ``plotTree``.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])  # no ticks
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

def minkowski_distance_p(x, y, p):
    """Return the Minkowski (L_p) distance between ``x`` and ``y``.

    Computes ``(sum |x - y| ** p) ** (1 / p)`` over all elements of the
    broadcast difference, so the result is a scalar. ``p = np.inf`` gives the
    Chebyshev (max absolute difference) distance.

    Note: unlike scipy's function of the same name, which returns the distance
    raised to the p-th power, this returns the p-th root, i.e. the distance
    itself.
    """
    # Absolute value is required: without it odd p gives negative sums
    # (and NaN after the fractional power).
    diff = np.abs(np.asarray(x) - np.asarray(y))
    if np.isinf(p):
        return np.max(diff)
    return np.power(np.sum(diff ** p), 1.0 / p)
    
class KdTree:
    """A k-d tree over the rows of a 2-D array with k-nearest-neighbour queries.

    Every node stores exactly one data row, identified by ``pivot_idx``. The
    split dimension cycles through the columns by depth (``height % n``). At
    each node the rows are sorted on the split coordinate: the median row
    becomes the node, earlier rows form ``less`` and later rows form
    ``greater``. Rows that tie the pivot coordinate may therefore land on
    either side, and none are dropped.

    Parameters
    ----------
    data : array_like
        A 2-D array of points, one per row (converted with ``np.asarray``).
    """

    def __init__(self, data):
        self.data = np.asarray(data)
        self.m, self.n = self.data.shape
        self.root = self.__build(np.arange(self.m), 0, None)
        self.current_node = None  # kept for compatibility; unused by this class

    class innerNode:
        """A tree node holding one data row and up to two child subtrees."""

        def __init__(self, split_dim, pivot, pivot_idx, height, less, greater, parent):
            self.split_dim = split_dim  # column this node splits on
            self.pivot = pivot          # data[pivot_idx][split_dim]
            self.pivot_idx = pivot_idx  # row index of the point stored here
            self.height = height        # depth below the root (root is 0)
            self.less = less            # subtree whose split coordinates are <= pivot
            self.greater = greater      # subtree whose split coordinates are >= pivot
            self.parent = parent
            self.visit = 0              # number of times query() examined this node

    class leafNode(innerNode):
        """A childless node: its subtree held exactly one row.

        It is a separate class so the plotting helpers can tell leaves from
        inner nodes by class name.
        """

        def __init__(self, split_dim, pivot, pivot_idx, height, less, greater, parent):
            super().__init__(split_dim, pivot, pivot_idx, height, less, greater, parent)

    def __build(self, idx, height, parent):
        """Recursively build the subtree over the rows ``idx`` at depth ``height``."""
        if len(idx) == 0:
            return None
        split_dim = height % self.n
        if len(idx) == 1:
            only = idx[0]
            return KdTree.leafNode(split_dim, self.data[only][split_dim], only,
                                   height, None, None, parent)

        coords = self.data[idx, split_dim]
        # Sorting, instead of comparing against a median value, guarantees
        # every row lands in exactly one child even when coordinates repeat.
        order = np.argsort(coords, kind="stable")
        mid = len(order) // 2
        pivot_pos = order[mid]

        node = KdTree.innerNode(split_dim, coords[pivot_pos], idx[pivot_pos],
                                height, None, None, parent)
        node.less = self.__build(idx[order[:mid]], height + 1, node)
        node.greater = self.__build(idx[order[mid + 1:]], height + 1, node)
        return node

    def __push(self, neighbors, x, k, node, p):
        """Offer ``node``'s point to ``neighbors``, which keeps the k best so far.

        ``neighbors`` is a heapq heap of ``(-distance, row_index)`` tuples, so
        ``neighbors[0]`` is always the farthest of the current candidates.
        """
        d = np.linalg.norm(x - self.data[node.pivot_idx], ord=p)
        entry = (-d, node.pivot_idx)
        if len(neighbors) < k:
            heappush(neighbors, entry)
        else:
            heappushpop(neighbors, entry)

    def __search(self, node, x, neighbors, k, p):
        """Depth-first k-NN search of the subtree rooted at ``node``."""
        if node is None:
            return
        node.visit += 1
        self.__push(neighbors, x, k, node, p)
        offset = x[node.split_dim] - node.pivot
        if offset < 0:
            near, far = node.less, node.greater
        else:
            near, far = node.greater, node.less
        self.__search(near, x, neighbors, k, p)
        # Every point in `far` is at least |offset| away along split_dim, and
        # so at least that far in any L_p norm with p >= 1. Skip `far` unless
        # it could still beat the current k-th best candidate.
        if len(neighbors) < k or abs(offset) < -neighbors[0][0]:
            self.__search(far, x, neighbors, k, p)

    def query(self, x, k=1, p=2):
        """Find the ``k`` rows of ``data`` nearest to ``x`` in the L_p norm.

        Parameters
        ----------
        x : array_like of length n
            The query point.
        k : int
            Number of neighbours to return. A value below 1 yields ``[]``.
        p : float
            Norm order, passed to ``np.linalg.norm`` as ``ord``. Pruning
            assumes p >= 1; ``np.inf`` is allowed.

        Returns
        -------
        list of (neg_distance, row_index) tuples
            The ``min(k, m)`` nearest rows, stored as a heapq heap keyed on the
            negated distance (``result[0]`` is the farthest of them). The list
            is not sorted; ``sorted(result, reverse=True)`` gives nearest-first.

        Raises
        ------
        ValueError
            If ``x`` does not have shape ``(n,)``.
        """
        x = np.asarray(x, dtype=float)
        if x.shape != (self.n,):
            raise ValueError(f"query point must have shape ({self.n},), got {x.shape}")
        neighbors = []
        if k < 1:
            return neighbors
        self.__search(self.root, x, neighbors, k, p)
        return neighbors

    def inorder(self, root):
        """Print the ``height`` of every node under ``root`` in in-order (less, node, greater)."""
        if root is None:
            return
        self.inorder(root.less)
        print(root.height)
        self.inorder(root.greater)

class DBScan:
    """Holder for DBSCAN settings; no clustering logic is implemented yet.

    Parameters
    ----------
    epsilon :
        Stored as ``self.epsilon`` (presumably the neighbourhood radius).
    minPts :
        Stored as ``self.minPts`` (presumably the minimum neighbour count for
        a core point).
    """

    def __init__(self, epsilon, minPts):
        self.epsilon, self.minPts = epsilon, minPts
        
if __name__ == "__main__":
    # Demo: index 100 random 7-D points, print a random 7-D vector, print each
    # node's height in in-order, then draw the tree.
    data = np.random.randn(100, 7)  # same draws as randn(700).reshape(100, -1)
    x = np.random.randn(7)
    print(x)
    kd = KdTree(data)
    kd.inorder(kd.root)
    createPlot(kd.root)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值