Scikit-learn-03.K-近邻算法

本系列文章介绍人工智能的基础概念和常用公式。由于协及内容所需的数学知识要求,建议初二以上同学学习。 运行本系统程序,请在电脑安装好Python、matplotlib和scikit-learn库。相关安装方法可自行在百度查找。

目录

算法优缺点

算法说明

示例程序


K-近邻算法(KNN)是有监督的机器学习算法。它的核心是未标记样本的类别,由距离其最近的K个邻居投票来决定。 假设X_test为待标记的数据样本,X_train为标记的数据集,计算过程如下:

  1. 遍历X_train中的所有样本,计算每个样本与X_test的距离,并把距离保存在Distance数组中。
  2. 对Distance数组进行排序,取距离最近的k的点,记为X_knn。
  3. 在X_knn中统计每个类别的个数,即Class0在X_knn中有几个样本,Class1在X_knn中有几个样本等。
  4. 待标记样本的类别,就是在X_knn中样本个数最多的那个类别。

算法优缺点

优点:准确性高,对异常值和噪声有较高的容忍度。

缺点:计算量较大,对内存需求高。


算法说明


示例程序

from sklearn.datasets.samples_generator import make_blobs
from sklearn.neighbors import KNeighborsClassifier
from matplotlib import pyplot as plt
import numpy as np
   
centers = [ [-2,2],[2,2],[0,4] ]
   
K = 5
   
X,Y = make_blobs(n_samples=60,centers=centers,random_state=0,cluster_std=0.6)
   
clf = KNeighborsClassifier(n_neighbors=K)
clf.fit(X,Y)
   
X_sample = [0,2]
XX_sample = np.array([0,2]).reshape(1,-1)
Y_sample = clf.predict(XX_sample)
neighbors = clf.kneighbors(XX_sample,return_distance=False)
   
plt.figure(figsize=(16,10),dpi=144)
plt.title("KNN")
   
c = np.array(centers)
plt.scatter(X[:,0],X[:,1],c=Y,s=100,cmap='cool')
plt.scatter(c[:,0],c[:,1],s=100,marker='^',c='green')
plt.scatter(X_sample[0],X_sample[1],marker='x',s=100,c='red',cmap='cool')
   
for i in neighbors[0]:
    plt.plot([X[i][0],X_sample[0]],[X[i][1],X_sample[1]],'k--',linewidth=0.5)
   
plt.show()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值