目录
当谈到机器学习和模式识别的基本算法之一时,K最近邻(K-Nearest Neighbors,KNN)算法常常是首选。KNN是一种简单而强大的分类和回归算法,它能够通过寻找与新数据点最相似的已标记数据点来进行预测。在本篇博客中,我们将详细介绍KNN算法的工作原理、KNN实现鸢尾花分类实例,应用领域和一些关键考虑因素。
1.KNN算法的工作原理
KNN算法的核心思想非常简单:新数据点的分类或数值预测取决于其最接近的K个邻居的投票或平均值。测试数据的预测结果取决于已知数据和测试数据的距离以及人为设置的k值。如图所示,假设k设置为3,由于测试数据最相近的3个已知数据有2个红色,1个蓝色,则预测结果为红色;假设k设置为5,由于测试数据最相近的5个已知数据又3个蓝色,2个红色,则预测结果为蓝色。
2.KNN的基本工作流程:
2.1选择K值:需要选择一个整数K,它代表了要考虑的最近邻居的数量。通常,K的选择是一个关键决策,它会影响模型的性能。
2.2计算距离:对于要预测的新数据点,计算它与训练集中所有已标记数据点之间的距离。距离度量方法有欧氏距离曼哈顿距离,切比雪夫距离来计算距离,KNN算法常用欧氏距离来计算:
2.3找到最近的K个邻居:从所有训练数据点中选择距离最近的K个点。
2.4投票或平均值:对于分类问题,K个邻居中各类别的出现次数进行统计,选择出现次数最多的类别作为预测结果。对于回归问题,计算K个邻居的平均值作为预测值。
这就是KNN的基本工作原理。它是一种懒惰学习(Lazy Learning)算法,因为它在训练阶段不会构建模型,而是在预测时根据训练数据进行实时计算。
3.实验之KNN实现鸢尾花分类实例
3.1实验环境:
pytorch 1.10.0+cpu
python3.7.5
win10+vscode
安装matplotlib、numpy、pandas、seaborn、sklearn,tensorflow等库,提示缺哪个库就装哪个库,只需输入pip install xxx(库名)即可(换源:pip +库名+ -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
)
实验目录:
3.2数据集介绍:
鸢尾花数据集iris.csv:共有150个样本,[Id,Sepal.Length Sepal.Width Petal.Length Petal.Width,cClass]对应编号,萼片长度、萼片宽度、花瓣长度、花瓣宽度,类别标签。
3.3数据可视化及其分析
提取鸢尾花的任意两个特征作为二维空间的坐标点进行可视化,来观察每个类别的属性分布范围。
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import pandas as pd
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False
TRAIN_URL = r'http://download.tensorflow.org/data/iris_training.csv'
train_path = tf.keras.utils.get_file(TRAIN_URL.split('/')[-1],TRAIN_URL)
names = ['Sepal length','Sepal width','Petal length','Petal width','Species']
df_iris = pd.read_csv(train_path,header=0,names=names)
iris_data = df_iris.values
plt.figure(figsize=(15,15),dpi=60)
for i in range(4):
for j in range(4):
plt.subplot(4,4,i*4+j+1)
if i==0:
plt.title(names[j])
if j==0:
plt.ylabel(names[i])
if i == j:
plt.text(0.3,0.4,names[i],fontsize = 15)
continue
plt.scatter(iris_data[:,j],iris_data[:,i],c= iris_data[:,-1],cmap='brg')
plt.tight_layout(rect=[0,0,1,0.9])
plt.suptitle('鸢尾花数据集\nBule->Setosa | Red->Versicolor | Green->Virginica', fontsize = 20)
plt.show()
3.4代码编写
3.4.1步骤:
3.4.1.1定义Data()函数,用于读取数据集。
3.4.1.2定义Datasets(iris)函数,用于划分数据集。该函数随机取15个样本作为测试集,其他样本作为训练集。函数返回一个包含测试集和训练集的List。
3.4.1.3定义KNN(Train, Test, GT, k)函数,用于实现KNN算法。其中,Train是训练集,Test是单个测试样本,GT是训练集标签,k是KNN算法中的超参数。首先计算测试样本到每个训练集样本的距离并排序,然后选择前k个最近邻居,再使用训练集标签中出现次数最多的标签作为预测结果。
3.4.1.4定义cross_define_K(Train, Test, GT)函数,用于对KNN算法进行交叉验证。该函数循环遍历k从1到49的值,每次将测试样本与训练数据交叉验证,计算准确率并将其存储在List precision中。最终,函数绘制出K值与准确率之间的关系图。
3.4.2代码部分:
import numpy as np
import pandas as pd
import math
from collections import Counter
import matplotlib.pyplot as plt
# 读取数据集
def Data():
iris=pd.read_csv('iris1.csv')
return iris
# 划分数据集
def Datasets(iris):
index=np.random.permutation(len(iris))
index=index[0:15]
Test = iris.take(index)
Train = iris.drop(index)
datasets = [Test, Train]
return datasets
# KNN算法
def KNN(Train, Test, GT, k):
Train_num = Train.shape[0]
tests = np.tile(Test, (Train_num, 1)) - Train
distance = (tests ** 2) ** 0.5
result = distance.sum(axis=1)
results = result.argsort()
label = []
for i in range(k):
label.append(GT[results[i]])
return label
#交叉验证
def cross_define_K(Train, Test, GT):
precision = []
for k in range(1,50):
#print(k)
true = 0
for i in Test:
Test1 = [i[0],i[1],i[2],i[3]]
result = KNN(Train,Test1,GT,k)
collection = Counter(result)
result = collection.most_common(1)
if result[0][0] == i[4]:
true += 1
success = true / len(Test)
precision.append(success)
k1 = range(1,50)
plt.plot(k1,precision,label='line1',color='g',marker='.',markerfacecolor='pink',markersize=10)
plt.xlabel('K')
plt.ylabel('Precision')
plt.title('KNN')
plt.legend()
plt.show()
if __name__ == "__main__":
# 读取iris数据集
iris = Data()
# 对数据集进行划分(训练集,测试集)
datasets = Datasets(iris)
print(datasets[0])
# 设置KNN的k值
k = 3
# 将训练集的GT隐去
Train = datasets[1].drop(columns=['Class']).values
# 读取训练集的GT
GT = datasets[1]['Class'].values
# 读取测试集
Test = datasets[0].values
cross_define_K(Train,Test,GT)
true = 0
for i in Test:
Test = [i[0],i[1],i[2],i[3]]
result = KNN(Train,Test,GT,k)
# KNN返回的是测试数据与训练数据相近的n个预测值
collection = Counter(result)
result = collection.most_common(1)
#print(result[0][0])
# 选取其中出现最多的结果进行验证
if result[0][0] == i[4]:
true += 1
success = true/len(datasets[0])
print('success:\n',success)
3.5 sklearn库实现KNN算法
sklearn库封装好了KNN算法 可以直接使用
import sklearn.datasets as datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
iris = datasets.load_iris()
feature = iris['data']
target = iris['target']
x_train, x_test, y_train, y_test = train_test_split(feature, target, test_size=0.2, random_state=2021)
print(x_train)
knn = KNeighborsClassifier(n_neighbors=3) #
knn = knn.fit(x_train, y_train)#训练
print(knn)
y_pred = knn.predict(x_test)#预测
y_true = y_test
print('模型的分类结果:', y_pred)
print('真实的分类结果:', y_true)
print(knn.score(x_test, y_test))
test1 = knn.predict([[6.1, 3.1, 4.7, 2.1]])
print(test1)
4.KNN的应用领域
KNN算法在各种领域都有广泛的应用,包括但不限于以下几个方面:
图像识别:KNN可以用于图像分类,通过比较像素值或特征向量来确定图像所属的类别。
推荐系统:KNN可用于协同过滤推荐系统,根据用户和物品之间的相似性来推荐商品或内容。
医学诊断:KNN可以用于根据患者病例的相似性来辅助医学诊断。
文本分类:KNN可用于将文本文档分类到不同的类别,例如垃圾邮件过滤或情感分析。
异常检测:KNN可用于检测异常值,例如信用卡欺诈检测。
自然语言处理:KNN在词嵌入(word embedding)和词义相似性计算等NLP任务中也有应用。
5.考虑因素和改进
虽然KNN算法简单且易于理解,但在实际应用中需要注意一些重要因素:
5.1特征选择:选择合适的特征对KNN的性能至关重要。不相关或冗余的特征可能会导致性能下降。
5.2距离度量:选择适合问题的距离度量方法很重要。欧氏距离是常用的度量方式,但在某些情况下,需要考虑其他度量方法,如曼哈顿距离或余弦相似度。
5.3K值选择:选择适当的K值通常需要进行交叉验证或其他技术来确定。
5.4数据预处理:对数据进行归一化或标准化可以改善KNN的性能,因为KNN对特征的尺度敏感。
5.5处理不平衡数据:在处理不平衡类别的问题时,需要考虑采样方法或不同的K值。
总之,KNN算法是一个强大而灵活的机器学习工具,但在使用时需要谨慎选择参数和特征,以确保获得最佳性能。这个算法的简单性和可解释性使它成为学习机器学习的很好起点,同时也在各种实际问题中取得了成功。
6.报错解决:
报错:raceback (most recent call last):
File "d:/Machine-learning/exp/exp1-knn/knn-iris.py", line 80, in <module>
cross_define_K(Train,Test,GT)
File "d:/Machine-learning/exp/exp1-knn/knn-iris.py", line 43, in cross_define_K
result = KNN(Train,Test1,GT,k)
File "d:/Machine-learning/exp/exp1-knn/knn-iris.py", line 26, in KNN
tests = np.tile(Test, (Train_num, 1)) - Train
ValueError: operands could not be broadcast together with shapes (135,4) (135,5)
出在数据的维度不匹配。在我的 KNN 函数中,我尝试将测试集 Test 复制成与训练集 Train 相同的形状,然后计算它们之间的距离,应该是由于测试集和训练集具有不同的列数,这导致了维度不匹配的错误。
刚学不是很明白,没改啥,就换了另一个鸢尾花数据集重启一遍又好了。
7.数据集下载网站:
7.1开放数据集-飞桨AI Studio星河社区开放数据集-飞桨AI Studio星河社区开放数据集-飞桨AI Studio星河社区