基于 kd 树的 k 近邻算法（kNN）—— Python 实现

import numpy as np


class binaryTreeNode():
    """Node of a kd-tree: holds a (sample, label) pair, two children and the split axis."""

    def __init__(self, data=None, left=None, right=None, split=None):
        self.data = data      # tuple (feature_vector, label)
        self.left = left      # child with smaller value on the split axis
        self.right = right    # child with greater-or-equal value on the split axis
        self.split = split    # index of the feature axis this node splits on

    def getdata(self):
        return self.data

    def getleft(self):
        return self.left

    def getright(self):
        return self.right

    def getsplit(self):
        return self.split


class KNNClassfier(object):
    """k-nearest-neighbor classifier backed by a kd-tree.

    NOTE(review): the search keeps a single best neighbor regardless of `k`,
    and the backtracking pass only revisits nodes on the descent path (an
    approximate, not exact, nearest-neighbor search) — kept as designed.
    """

    def __init__(self, k=1, distance='euc'):
        self.k = k                  # number of neighbors (search currently keeps 1)
        self.distance = distance    # distance metric name ('euc' = Euclidean)
        self.root = None            # root of the kd-tree built by fit()

    def getroot(self):
        return self.root

    def kd_tree(self, train_X, train_Y):
        """Recursively build a kd-tree.

        Splits on the feature axis of maximum variance, at the median sample;
        the median goes into the node, smaller samples to the left subtree,
        the rest to the right. Returns the subtree root (None for no data).
        """
        if len(train_X) == 0:
            return None
        if len(train_X) == 1:
            return binaryTreeNode((train_X[0], train_Y[0]))
        index = np.argmax(np.var(train_X, axis=0))
        argsort = np.argsort(train_X[:, index])
        mid = len(argsort) // 2
        left = self.kd_tree(train_X[argsort[:mid], :], train_Y[argsort[:mid]])
        right = self.kd_tree(train_X[argsort[mid + 1:], :], train_Y[argsort[mid + 1:]])
        return binaryTreeNode((train_X[argsort[mid], :], train_Y[argsort[mid]]),
                              left, right, index)

    def inOrder(self, root):
        """In-order traversal of the kd-tree; prints each node's (sample, label)."""
        if root is None:
            return None
        self.inOrder(root.getleft())
        print(root.getdata())
        self.inOrder(root.getright())

    def search_kd_tree(self, x, knn, root, nodelist):
        """Find (approximately) the training point nearest to x.

        Descends to a leaf choosing the child on x's side of each splitting
        hyperplane, recording the path in `nodelist`; then backtracks over the
        recorded nodes (and their direct children), keeping the closest point
        seen in knn[0]. Returns knn, a one-element list of (sample, label).
        """
        while len(knn) == 0:
            if root.getleft() is None and root.getright() is None:
                # Leaf reached: seed the best-so-far and move to backtracking.
                # (Original returned `knn.append(...)` i.e. None, which crashed
                # predict() when the whole tree was a single leaf.)
                knn.append(root.getdata())
                break
            # Descend toward the side of the splitting hyperplane containing x;
            # fall back to the other child when the preferred one is missing.
            if x[root.getsplit()] < root.getdata()[0][root.getsplit()]:
                if root.getleft() is not None:
                    nodelist.append(root.getleft())
                    self.search_kd_tree(x, knn, root.getleft(), nodelist)
                else:
                    nodelist.append(root.getright())
                    self.search_kd_tree(x, knn, root.getright(), nodelist)
            else:
                if root.getright() is not None:
                    nodelist.append(root.getright())
                    self.search_kd_tree(x, knn, root.getright(), nodelist)
                else:
                    nodelist.append(root.getleft())
                    self.search_kd_tree(x, knn, root.getleft(), nodelist)
        # Backtrack: test every recorded node and its children against the
        # current best distance, tightening `dis` whenever we improve (without
        # the update a farther point could later overwrite a nearer one).
        dis = np.linalg.norm(x - knn[0][0], ord=2)
        while len(nodelist) != 0:
            current = nodelist.pop()
            currentdis = np.linalg.norm(x - current.getdata()[0], ord=2)
            if currentdis < dis:
                knn[0] = current.getdata()
                dis = currentdis
            if current.getleft() is not None:
                leftdis = np.linalg.norm(x - current.getleft().getdata()[0], ord=2)
                if leftdis < dis:
                    knn[0] = current.getleft().getdata()
                    dis = leftdis
            if current.getright() is not None:
                rightdis = np.linalg.norm(x - current.getright().getdata()[0], ord=2)
                if rightdis < dis:
                    knn[0] = current.getright().getdata()
                    dis = rightdis
        return knn

    def fit(self, X, Y):
        """Build the kd-tree from training data.

        X : array-like [n_samples, n_features]
        Y : array-like [n_samples]
        """
        self.root = self.kd_tree(X, Y)

    def predict(self, X):
        """Return an [n_samples, 1] array of predicted labels (majority vote
        among the neighbors returned by search_kd_tree — currently one)."""
        output = np.zeros((X.shape[0], 1))
        for i in range(X.shape[0]):
            knn = self.search_kd_tree(X[i, :], [], self.root, [self.root])
            labels = [pair[1] for pair in knn]
            counts = [labels.count(label) for label in labels]
            output[i] = labels[np.argmax(counts)]
        return output

    def score(self, X, Y):
        """Return classification accuracy of predict(X) against Y."""
        pred = self.predict(X)
        err = 0.0
        for i in range(X.shape[0]):
            if pred[i] != Y[i]:
                err = err + 1
        return 1 - float(err / X.shape[0])


if __name__ == '__main__':
    # Demo: compare against scikit-learn on the digits dataset.
    from sklearn import datasets
    import time

    digits = datasets.load_digits()
    x = digits.data
    y = digits.target

    myknn_start_time = time.time()
    clf = KNNClassfier(k=5)
    clf.fit(x, y)
    print('myknn score:', clf.score(x, y))
    myknn_end_time = time.time()

    from sklearn.neighbors import KNeighborsClassifier
    sklearnknn_start_time = time.time()
    clf_sklearn = KNeighborsClassifier(n_neighbors=5)
    clf_sklearn.fit(x, y)
    print('sklearn score:', clf_sklearn.score(x, y))
    sklearnknn_end_time = time.time()

    print('myknn uses time:', myknn_end_time - myknn_start_time)
    print('sklearn uses time:', sklearnknn_end_time - sklearnknn_start_time)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值