K近邻分类算法
项目链接:https://github.com/Wchenguang/gglearn/blob/master/KNNClassifier-todo/李航机器学习讲解/KNNClassifier.ipynb
公式笔记
-
Lp距离公式
$$
L_{p}\left(x_{i}, x_{j}\right)=\left(\sum_{l=1}^{n}\left|x_{i}^{(l)}-x_{j}^{(l)}\right|^{p}\right)^{\frac{1}{p}}, \quad p \geq 1
$$
-
$p=2$ 时为欧氏距离
$$
L_{2}\left(x_{i}, x_{j}\right)=\left(\sum_{l=1}^{n}\left|x_{i}^{(l)}-x_{j}^{(l)}\right|^{2}\right)^{\frac{1}{2}}
$$
-
$p=1$ 时为曼哈顿距离
$$
L_{1}\left(x_{i}, x_{j}\right)=\sum_{l=1}^{n}\left|x_{i}^{(l)}-x_{j}^{(l)}\right|
$$
'''
kd树构造
'''
class kdTreeNode:
    """A single node of a kd-tree.

    All nodes keep a reference to the same sample matrix ``x`` (it is not
    copied); the samples a node covers are identified by ``index``, a
    collection of row indices into ``x``.
    """

    def __init__(self, x, index, target_dim, parent):
        """Create a node with no children and no split row yet.

        Args:
            x: the full sample matrix, shared (by reference) by all nodes.
            index: row indices of ``x`` covered by this node.
            target_dim: coordinate this node splits on.
            parent: parent node, or ``None`` for the root.
        """
        self.x, self.index = x, index
        self.target_dim, self.parent = target_dim, parent
        # Children and split row start empty; the setters below fill them in.
        self.left = self.right = None
        self.target_index = None

    def set_left(self, left):
        """Attach ``left`` as this node's left child."""
        self.left = left

    def set_right(self, right):
        """Attach ``right`` as this node's right child."""
        self.right = right

    def set_target_index(self, target_index):
        """Record the row index of ``x`` used as this node's split point."""
        self.target_index = target_index
class kdTree:
    """kd-tree built over the rows of a 2-D sample matrix.

    The tree is built by recursively splitting on the median along one
    coordinate, cycling through the coordinates (0, 1, ..., dim-1, 0, ...)
    as depth increases. Nodes store row indices into ``self.x`` rather than
    copies of the samples.
    """

    def __init__(self, x):
        """Store a private copy of the samples.

        Args:
            x: array-like of shape (n_samples, n_features). Plain lists are
                accepted and converted to a NumPy array.

        Raises:
            ValueError: if ``x`` is not 2-D or contains no values; a tree
                needs at least one sample and one coordinate to split on.
        """
        # np.array always copies, so later edits to the caller's data cannot
        # corrupt the tree; unlike ``x.copy()`` it also accepts plain lists,
        # which the fancy indexing in _get_sorted_index requires.
        self.x = np.array(x)
        if self.x.ndim != 2 or self.x.size == 0:
            raise ValueError(
                "x must be a non-empty 2-D array of shape (n_samples, n_features)"
            )
        self.dim = self.x.shape[1]
        # Becomes True once fit() has built the tree.
        self.tree_done = False

    def _fit(self, node):
        """Recursively split ``node``'s samples at the median of its axis.

        The median row becomes ``node.target_index``. Rows sorting before or
        after it go to new left/right children, which split on the next
        coordinate. A child is created only when its side is non-empty.
        """
        if node is None:
            return
        left_index, mid_index, right_index = self._get_sorted_index(
            node.index, node.target_dim
        )
        node.set_target_index(mid_index)
        next_dim = (node.target_dim + 1) % self.dim
        if len(left_index) != 0:
            node.set_left(kdTreeNode(self.x, left_index, next_dim, node))
            self._fit(node.left)
        if len(right_index) != 0:
            node.set_right(kdTreeNode(self.x, right_index, next_dim, node))
            self._fit(node.right)

    def fit(self):
        """Build the tree over all samples and return ``self``."""
        self.root = kdTreeNode(self.x, np.arange(len(self.x)), 0, None)
        self._fit(self.root)
        self.tree_done = True
        return self

    def _get_sorted_index(self, index_list, dim):
        """Split ``index_list`` at the median of coordinate ``dim``.

        Args:
            index_list: non-empty collection of integer row indices into
                ``self.x``.
            dim: coordinate to sort on.

        Returns:
            ``(left, mid, right)``: ``left``/``right`` are integer arrays of
            the row indices sorting before/after the median row, and ``mid``
            is that row's index. For an even count the upper median is used.
        """
        index_list = np.asarray(index_list)
        # Sort the indices themselves rather than appending them as an extra
        # column: stacking them beside float samples turned them into floats,
        # which NumPy then rejected as indices in the recursive child call.
        # A stable sort makes tie-breaking deterministic (input order).
        order = index_list[np.argsort(self.x[index_list, dim], kind="stable")]
        mid = len(order) // 2
        return order[:mid], order[mid], order[mid + 1:]