KNN的三要素
KNN实现方式
KD Tree的构建
在数据量少的时候我们可以通过暴力的方式来构建,但是数据量大的时候,这样做的计算量非常大,所以我们可以通过KD Tree的方式来进行计算。
- 计算所有特征的方差,选取方差大的特征来进行划分,选取这个特征的中位数作为分割点,小于这个中位数的划分到左子树,大于这个中位数的划分到右子树。
- 然后对于左右两颗子树继续上面的方法进行划分。
上图数据中样本有两维。可以看出第一维的方差较大,我们选择第一位的中位数来作为划分节点。将第一维排序, (2,3) (4,7) (5,4) (7,2) (8,1) (9,6)。
选择(5,4) ,(7,2) 作为分割点都可以,这里我们选择(7,2)。
左子树(2,3) (4,7) (5,4),右子树 (8,1) (9,6)。此刻,左子树的第二个维度方差较大,排序(2,3) (5,4) (4,7) ,选择(5,4)作为分割点,后面省略
KD Tree寻找最近邻
- 找到目标样本在KD Tree中的叶子节点,以目标节点为圆心,到根节点的距离为半径,得到一个超球体。
- 遍历超球体相交或者相切的线,找到最近的根节点,划分超球体,一直迭代
- 最终所有可能的切线遍历一边,找到最近邻。
代码
对标签做独热编码
# label_encoder=LabelEncoder()
# label_encoder.fit(Y)
# Y=label_encoder.transform(Y)
# print(Y)
algo=KNeighborsClassifier(n_neighbors=10)
"""
n_neighbors=5, 算法当中的k值 最近邻的几个点
weights='uniform', 分类回归规则 用来确定样本权重 uniform 所有样本等权重 distance 权重按照距离成反比
algorithm='auto', 求解方法 kd_ball brute 暴力求解
leaf_size=30, 限制叶节点的数量 可以防止过拟合
p=2, 距离度量公式中闵可夫斯基的 参数值 p=2时 欧几里得距离
metric='minkowski',
n_jobs=None, 调用几个线程来进行模型运行
"""
# 预测模型
pred_train = algo.predict(x_train)
pred_test = algo.predict(x_test)
# # 6.模型效果评估 (分类 回归 聚类)
print("训练集上的准确率:{}".format(algo.score(x_train,y_train)))
print("测试集上的准确率:{}".format(algo.score(x_test,y_test)))
print("测试集上的准确率:{}".format(f1.score(y_test,algo.predict(x_test),average='macor'))) #计算F1 默认是二分类的,需要设置参数修改为多分类
print("训练集上的准确率:{}".format(accuracy_score(y_train,pred_train)))
print("测试集上的准确率:{}".format(accuracy_score(y_test,pred_test)))
print("训练集上的混淆矩阵:\n{}".format(confusion_matrix(y_train,pred_train)))
print("测试集上的混淆矩阵:\n{}".format(confusion_matrix(y_test,pred_test)))
print("="*100)
# 分别计算每一类的auc
#
y_true = label_binarize(y_test, classes=(1,2,3))
y_score = algo.predict_proba(x_test)
# print(y_score)
# # print(y_true.shape)
# print(y_score.shape)
# 计算鸢尾花种类一的AUC的值
fpr,tpr,threads=roc_curve(y_true[:,0],y_score[:,0])
auc_score=auc(fpr,tpr)
print(auc_score)
print("="*100)
# 计算鸢尾花种类二的AUC的值
fpr,tpr,threads=roc_curve(y_true[:,1],y_score[:,1])
auc_score=auc(fpr,tpr)
print(auc_score)
print("="*100)
# 计算鸢尾花种类三的AUC的值
fpr,tpr,threads=roc_curve(y_true[:,2],y_score[:,2])
auc_score=auc(fpr,tpr)
print(auc_score)