# 1. KD-tree construction (ongoing)
# DBSCAN is used to detect abnormal samples (outliers).
import numpy as np
from heapq import heappush, heappop, nsmallest, heappushpop
from scipy.spatial import KDTree
import matplotlib.pyplot as plt
# Matplotlib annotation styles shared by the tree-plotting helpers below.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}  # bbox for internal (splitting) nodes
leafNode = {"boxstyle": "round4", "fc": "0.8"}  # bbox for leaf nodes
arrow_args = {"arrowstyle": "<-"}  # arrow drawn between a node and its parent
def getNumLeafs(mytree):
    """Count the leaf slots underneath ``mytree``.

    A child whose class is named ``innerNode`` is expanded recursively; any
    other child (a leaf node object or ``None``) contributes exactly one slot.
    """
    # The class *name* is compared on purpose: subclasses of innerNode
    # (e.g. leafNode) must not be expanded.
    return sum(
        getNumLeafs(child) if type(child).__name__ == 'innerNode' else 1
        for child in (mytree.less, mytree.greater)
    )
def getTreeDepth(mytree):
    """Return the number of levels below ``mytree``.

    A child whose class is named ``innerNode`` adds one level plus its own
    depth; any other child (leaf node object or ``None``) adds one level.
    """
    child_depths = []
    for child in (mytree.less, mytree.greater):
        if type(child).__name__ == 'innerNode':
            child_depths.append(1 + getTreeDepth(child))
        else:
            child_depths.append(1)
    # The deeper of the two branches decides the depth of this subtree.
    return max(child_depths)
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Annotate the shared axes with one tree node.

    Args:
        nodeTxt: label shown inside the node box.
        centerPt: (x, y) position of the label, in axes-fraction coordinates.
        parentPt: (x, y) position the arrow is attached to, in axes-fraction
            coordinates.
        nodeType: bbox style dict for the node box.
    """
    axes = createPlot.ax1  # axes created by createPlot()
    axes.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` halfway between ``cntrPt`` and ``parentPt``.

    Both points are (x, y) pairs; the text is drawn on the axes created by
    ``createPlot`` and rotated by 30 degrees.
    """
    child_x, child_y = cntrPt[0], cntrPt[1]
    parent_x, parent_y = parentPt[0], parentPt[1]
    mid_x = (parent_x - child_x) / 2.0 + child_x
    mid_y = (parent_y - child_y) / 2.0 + child_y
    createPlot.ax1.text(mid_x, mid_y, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` and everything below it.

    Layout state is kept in function attributes initialised by
    ``createPlot``: ``plotTree.totalW`` and ``plotTree.totalD`` (leaf count and
    depth of the whole tree) and the running cursor ``plotTree.xOff`` /
    ``plotTree.yOff``.

    Args:
        myTree: node exposing ``less``, ``greater`` and ``pivot_idx``.
        parentPt: (x, y) of the parent node in axes-fraction coordinates.
        nodeTxt: edge label; currently unused because edge labels are
            disabled.
    """
    # The number of leaf slots below this node determines its x-width.
    leaf_slots = getNumLeafs(myTree)
    # Centre the node above the horizontal span of its leaf slots.
    cntrPt = (
        plotTree.xOff + (1.0 + float(leaf_slots)) / 2.0 / plotTree.totalW,
        plotTree.yOff,
    )
    plotNode(myTree.pivot_idx, cntrPt, parentPt, decisionNode)

    # Step one level down for the children ...
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for child in (myTree.less, myTree.greater):
        if type(child).__name__ == 'innerNode':
            plotTree(child, cntrPt, child.split_dim)  # recursion
        else:
            # A leaf node or an empty slot: both consume one unit of width so
            # that the layout matches getNumLeafs(); only real leaves are drawn.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            if child is not None:
                plotNode(child.pivot_idx, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
    # ... and step back up before returning to the caller.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Render ``inTree`` in matplotlib figure 1 and show it.

    Creates a frameless, tick-less axes (stored on ``createPlot.ax1``),
    initialises the layout state consumed by ``plotTree`` and draws the tree
    starting from the top centre of the axes.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])  # no ticks

    # Overall extent of the tree, used to scale node positions into [0, 1].
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot to the left so the first leaf is centred in its slot.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0

    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
def minkowski_distance_p(x, y, p):
    """Return the Minkowski distance of order ``p`` between ``x`` and ``y``.

    Computes ``(sum(|x - y| ** p)) ** (1 / p)`` over all elements of the
    broadcast difference; ``p = inf`` gives the largest absolute difference.

    Args:
        x: array-like point (or array broadcastable against ``y``).
        y: array-like point.
        p: order of the norm, e.g. 1 (Manhattan), 2 (Euclidean) or ``np.inf``.

    Returns:
        The distance as a NumPy scalar.
    """
    # The absolute value is essential: without it odd orders (p=1, p=3, ...)
    # let negative differences cancel positive ones or produce NaN.
    diff = np.abs(np.asarray(x) - np.asarray(y))
    if np.isinf(p):
        return np.max(diff, initial=0.0)
    return np.power(np.sum(diff ** p), 1 / p)
class KdTree:
    """k-d tree over the rows of a 2-D array with k-nearest-neighbour search.

    Every node stores exactly one data point (identified by its row index).
    The split axis cycles with the depth: ``height % number_of_columns``.
    Points whose split coordinate is strictly below the node's pivot go to
    ``less``; all remaining points (including ties with the pivot) go to
    ``greater``, so no input row is ever dropped.

    Attributes:
        data: the input as an ``(m, n)`` NumPy array.
        m, n: number of rows and columns of ``data``.
        root: root node, or ``None`` when ``data`` has no rows.
        current_node: always initialised to ``None``; not used by this class.
    """

    def __init__(self, data):
        """Build the tree from ``data``.

        Args:
            data: array-like of shape ``(m, n)``, one point per row.

        Raises:
            ValueError: if ``data`` is not two-dimensional.
        """
        self.data = np.asarray(data)
        if self.data.ndim != 2:
            raise ValueError(
                "data must be a 2-D array of shape (m, n), got %d dimension(s)"
                % self.data.ndim
            )
        self.m, self.n = self.data.shape
        self.root = self.__build(np.arange(self.m), 0, None)
        self.current_node = None

    class innerNode:
        """Node that stores one point and splits its subtree along one axis."""

        def __init__(self, split_dim, pivot, pivot_idx, height, less, greater, parent):
            self.split_dim = split_dim  # axis this node splits on
            self.pivot = pivot  # the node's coordinate along split_dim
            self.pivot_idx = pivot_idx  # row index of the node's point in KdTree.data
            self.height = height  # depth of the node; the root has height 0
            self.less = less  # subtree with split coordinate < pivot (or None)
            self.greater = greater  # subtree with split coordinate >= pivot (or None)
            self.parent = parent  # parent node (None for the root)
            self.visit = 0  # how many times a query examined this node

    class leafNode(innerNode):
        """Node built from a single point; both children are ``None``.

        Kept as a distinct class because the plotting helpers tell leaves and
        inner nodes apart by class name.
        """

    def __build(self, idx, height, parent):
        """Recursively build the subtree for the rows listed in ``idx``.

        Args:
            idx: 1-D integer array of row indices into ``self.data``.
            height: depth of the node being created (root is 0).
            parent: parent node, or ``None`` for the root.

        Returns:
            The subtree root, or ``None`` when ``idx`` is empty.
        """
        if len(idx) == 0:
            return None
        split_dim = height % self.n
        if len(idx) == 1:
            return KdTree.leafNode(
                split_dim, self.data[idx[0]][split_dim], idx[0], height, None, None, parent
            )

        values = self.data[idx, split_dim]
        # np.median averages the two middle values of an even-length array,
        # which would not be an actual data value.  Padding with one existing
        # element makes the length odd so the median is always a real point.
        if len(values) % 2 == 0:
            pivot = np.median(np.append(values, values[0]))
        else:
            pivot = np.median(values)
        pivot_pos = np.argwhere(values == pivot)[0][0]

        below = values < pivot
        # Everything that is not strictly below goes right -- including points
        # that tie with the pivot -- except the pivot point itself.
        at_or_above = ~below
        at_or_above[pivot_pos] = False

        node = KdTree.innerNode(split_dim, pivot, idx[pivot_pos], height, None, None, parent)
        # NOTE: recurse on row indices (idx[mask]), not on positions within
        # the current subset.
        node.less = self.__build(idx[below], height + 1, node)
        node.greater = self.__build(idx[at_or_above], height + 1, node)
        return node

    def __push(self, neighbors, x, k, node, p):
        """Offer ``node``'s point to the bounded heap of the ``k`` best candidates.

        ``neighbors`` holds ``(-distance, row_index)`` tuples, so the heap root
        is always the farthest candidate and is the one evicted when full.
        """
        d = minkowski_distance_p(x, self.data[node.pivot_idx], p)
        if len(neighbors) < k:
            heappush(neighbors, (-d, node.pivot_idx))
        else:
            heappushpop(neighbors, (-d, node.pivot_idx))

    def query(self, x, k=1, p=2):
        """Find the ``k`` points of the tree nearest to ``x``.

        Args:
            x: query point, array-like of length ``n``.
            k: number of neighbours requested.
            p: order of the Minkowski distance.

        Returns:
            A list of ``(-distance, row_index)`` tuples arranged as a heap
            (element 0 is the farthest of the returned neighbours).  The list
            has ``min(k, m)`` entries and is empty for an empty tree or
            ``k < 1``.
        """
        x = np.asarray(x)
        neighbors = []
        if self.root is None or k < 1:
            return neighbors

        # Depth-first search with an explicit stack.  Each entry carries a
        # lower bound on the distance from x to any point in that subtree; a
        # subtree is skipped once k candidates are known and none of its
        # points can beat the current farthest candidate.
        pending = [(self.root, 0.0)]
        while pending:
            node, bound = pending.pop()
            if len(neighbors) >= k and bound >= -neighbors[0][0]:
                continue
            node.visit += 1
            self.__push(neighbors, x, k, node, p)

            offset = x[node.split_dim] - node.pivot
            if offset < 0:
                near, far = node.less, node.greater
            else:
                near, far = node.greater, node.less
            # Points on the far side are at least |offset| away along the
            # split axis, hence at least that far in any Minkowski distance.
            if far is not None:
                pending.append((far, max(bound, abs(offset))))
            # Pushed last so the side containing x is explored first.
            if near is not None:
                pending.append((near, bound))
        return neighbors

    def inorder(self, root):
        """Print the ``height`` of every node below ``root`` in in-order."""
        if root is None:
            return
        self.inorder(root.less)
        print(root.height)
        self.inorder(root.greater)
class DBScan:
    """Holder for DBSCAN parameters; no clustering logic is implemented here.

    NOTE(review): ``epsilon`` is presumably the neighbourhood radius and
    ``minPts`` the minimum neighbour count for a core point -- confirm once
    the algorithm itself is added.
    """

    def __init__(self, epsilon, minPts):
        self.epsilon, self.minPts = epsilon, minPts
if __name__ == "__main__":
    # Demo: build a k-d tree over random points, print the node heights in
    # in-order and plot the resulting tree.
    n_samples, n_features = 100, 7
    data = np.random.randn(n_samples * n_features).reshape((n_samples, -1))
    # For a reproducible run, swap in a small fixed dataset instead, e.g.:
    # data = np.array([[ 0.74728798, 0.81022863, -0.19179337, 0.878292 ],
    #                  [-2.13781247, 0.91024753, 0.09538944, -0.29745797],
    #                  [ 0.45066661, -0.27623008, 0.15601932, -1.97192213],
    #                  [ 0.79890978, 2.01713301, -0.00664947, -0.37733724],
    #                  [-0.75239458, 0.56911767, 1.31537443, -0.6950948 ]])

    # Random query point; it is only printed here, no query is issued yet.
    x = np.random.randn(n_features)
    print(x)

    kd = KdTree(data)
    kd.inorder(kd.root)
    # Reference implementation for comparison: scipy.spatial.KDTree(data)
    createPlot(kd.root)