一. KD树的建立
KD树算法包括三步,第一步是建树,第二步是搜索最近邻,最后一步是预测。
有二维样本6个,{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)},构建kd树的具体步骤为:
1)找到划分的特征。6个数据点在x,y维度上的数据方差分别为5.80,4.47,所以在x轴上方差更大,用第1维特征建树。
计算x 轴的方差:
Ex=(2+5+9+4+8+7) / 6 = 35/6
Sx=((2- 35/6)^2 + (5- 35/6) ^ 2 + (9 -35/6) ^2 + (4- 35/6) ^2 +(8 - 35/6) ^ 2 + (7-35/6) ^ 2) /6 = 5.80
计算y轴的方差:
Ey =(3+4+6+7+1+2) / 6 = 23/6
Sy=((3-23/6)^2 + (4-23/6) ^2 + (6-23/6) ^2+ (7-23/6) ^2 +(1-23/6) ^2 + (2 -23/6) ^2)/6 = 4.47
2)确定划分点(7,2)。根据x维上的值将数据排序,6个数据的中值(所谓中值,即中间大小的值)为7,所以划分点的数据是(7,2)。这样,该节点的分割超平面就是通过(7,2)并垂直于:划分点维度的直线x=7;
3)确定左子空间和右子空间。 分割超平面x=7将整个空间分为两部分:x<=7的部分为左子空间,包含3个节点={(2,3),(5,4),(4,7)};另一部分为右子空间,包含2个节点={(9,6),(8,1)}。
划分左边三个点:
计算x轴方差:
Ex = (2 + 5 + 4)/3 = 11/3
Sx = ((2 - 11/3) ^ 2 + (5-11/3) ^ 2 + (4 - 11/3) ^2 ) /3= 1.56
计算y 轴方差:
Ey = (3 + 4 + 7) / 3 = 14/3
Sy = ((3- 14/3) ^ 2 +(4 -14/3) ^ 2 +(7 - 14/3) ^ 2)/3 = 2.89所以y轴的方差比x轴的方差大,按y轴的方向进行分割,先对y轴坐标进行排序,找出划分点. 查找后得到划分点是(5,4), 并且垂直于y = 4
此时已经把左边点分开了,y轴比4小的为于直线下方,比4大的位于直线上方。
4)用同样的办法划分左子树的节点{(2,3),(5,4),(4,7)}和右子树的节点{(9,6),(8,1)}。最终得到KD树。
上面的过程用树结构(排序树)来表示就是:
二. KD树搜索最近邻
生成KD树以后,就可以去预测测试集里面的样本目标点了。对于一个目标点,首先在KD树里面找到包含目标点的叶子节点。以目标点为圆心,以目标点到叶子节点样本实例的距离为半径,得到一个超球体,最近邻的点一定在这个超球体内部。
用建立的KD树,来看对点(2, 4.5)找最近邻的过程。
(1) 首先在KD树里面找到包含目标点的叶子节点,从根结点开始遍历,找到叶子结点(4, 7) :
然后以目标点为圆心,叶子结点(4,7)到目标点的距离为半径画圆,然后回溯到父结点(5,4). 这里发现(5,4 )到目标点的距离比(4,7)到目标点的距离近,所以我们直接以(2,4,5) 为圆心,(5,4)为半径画圆。
从图中发现在圈内还有一个点(2,3),那么现在比较发现该点到目标点的距离比(5,4)到目标点的距离还小,那么接下来以(2,3)到目标点的距离为半径,(2, 4.5)为圆心画圆。
发现圈内再没有其它点,搜索路径回溯完,返回最近邻点(2,3),最近距离1.5。
三. KD树预测
有了KD树搜索最近邻的办法,KD树的预测就很简单了,在KD树搜索最近邻的基础上,我们选择到了第一个最近邻样本,就把它置为已选。在第二轮中,我们忽略置为已选的样本,重新选择最近邻,这样跑k次,就得到了目标的K个最近邻,然后根据多数表决法,如果是KNN分类,预测为K个最近邻里面有最多类别数的类别。如果是KNN回归,用K个最近邻样本输出的平均值作为回归预测值。
四. sklearn 实现KD树
# 基础结构.py
#
import numpy as np
from sklearn import linear_model, svm, neighbors, datasets, preprocessing
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
from sklearn.model_selection import cross_val_score
# 关闭报警
import warnings
warnings.filterwarnings("ignore")
np.random.RandomState(0)
# 加载数据
iris = datasets.load_iris()
x, y = iris.data, iris.target
# 划分训练集与测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3)
# 数据预处理
scaler = preprocessing.StandardScaler().fit(x_train)
x_train = scaler.transform(x_train)
x_test = scaler.transform(x_test)
# 创建模型
clf = neighbors.KNeighborsClassifier(n_neighbors=12,algorithm='kd_tree')
# clf = linear_model.SGDClassifier()
# clf = linear_model.LogisticRegression()
# clf = svm.SVC(kernel='rbf')
# 模型拟合
clf.fit(x_train, y_train)
# 预测
y_pred = clf.predict(x_test)
# 评估
print(accuracy_score(y_test, y_pred))
# f1_score
print(f1_score(y_test, y_pred, average='micro'))
# 分类报告
print(classification_report(y_test, y_pred))
# 混淆矩阵
print(confusion_matrix(y_test, y_pred))