k k k近邻
前面介绍了 k k k近邻的相关原理,这里我们举一个相关的示例(python 3)。
数据来源
此次我们使用几个自己定义的简单样本进行一个验证(我们使用 k d kd kd树)。
程序实现
- 引入相关模块
import numpy as np
- 构建结点的类
###定义结点类
class Node:
def __init__(self,data,lchild=None,rchild=None):
self.data=data #根节点
self.lchild=lchild #左子树
self.rchild=rchild #右子树
- 构建 k d kd kd树的类
class KdTree:
def __init__(self):
self.kdTree=None
#定义构建kd树的方法
def create(self,dataset,depth):
if (len(dataset)>0):
m,n=np.shape(dataset) #获得样本的行、列数
midIndex=int(m/2) #中间数的索引位置
axis = depth%n #确定分类的属性
sortedDataset=self.sort(dataset=set,axis) #进行排序
node=Node(sortedDataset[midIndex]) #根据中间点创建树结构
leftDataset=sortedDataset[:midIndex] #左边保存小样本
rightDataset=sortedDataset[midIndex:] #右边保存大样本
#递归构建左右子树
nold.lchild=self.create(leftDataset,depth+1)
nold.rchild=self.create(rightDataset,depth+1)
return node
else:
return None
#定义排序用到的算法(冒泡)
def sort(self,dataset,axis):
sortDataset=dataset[:]
m,n=np.shape(sortDataset) #获取行列
# 进行排序,每次能找出一个元素的位置(最大值从大到小位置逐步确定)
for i in range(m):
for j in range(0,m-i-1):
if (sortDataset[j][axis]>sortDataset[j+1][axis]):
temp=sortDataset[j]
sortDataset[j]=sortDataset[j+1]
sortDataset[j+1]=temp
return sortDataset
def preOrder(self,node):
if node!=None:
print("tttt->%s"%node.data)
self.preOrder(node.lchild)
self.preOrder(node.rchild)
#定义搜索方法
def search(self,tree,x):
self.nearestPoint=None #保存最近点
self.nearestValue=0 #保存最近点的值
#递归搜索
def travel(node,depth=0):
if node!=None:
n=len(x) #特征数
axis=depth%n #计算轴(属性)
#从根节点找到包含目标结点的结点
if x[axis]<node.data[axis]:
travel(node.lchild,depth+1)
else:
travel(node.rchild,depth+1)
#反方向回溯
distNodeAndX=self.dist(x,node.data) #计算x与节点之间的距离
if (self.nearestPoint==None):
self.nearestPoint=node.data
self.nearestValue=distNodeAndX
elif (self.nearestValue>distNodeAndX):
self.nearestPoint=node.data
self.nearestValue=distNodeAndX
#判断是否需要去另一个子节点搜索
if (abs(x[axis]-node.data[axis])<=self.nearestValue):
if x[axis]<node.data[axis]:
travel(node.rchild,depth+1)
else:
travel(node.lchild,depth+1)
travel(tree)
return self.nearestPoint
#定义距离计算
def dist(self,x1,x2):
return ((np.array(x1)-np.array(x2))**2).sum()**0.5
- 验证
dataSet = [[2, 3],
[5, 4],
[9, 6],
[4, 7],
[8, 1],
[7, 2]]
x = [5, 3]
kdtree = KdTree()
tree = kdtree.create(dataSet, 0)
kdtree.preOrder(tree)
print(kdtree.search(tree, x))```
完整代码
https://github.com/canshang/-/blob/master/感知机.ipynb
参考文献
《机器学习实战》
https://blog.csdn.net/tudaodiaozhale/article/details/77327003