一.KNN算法的介绍
1.1 KNN算法概述
KNN(K-Nearest Neighbor)算法是机器学习算法中最基础、最简单的算法之一,是一种分类和回归的统计方法,是监督学习。KNN通过测量不同特征值之间的距离来进行分类。所谓k近邻,就是k个最近的邻居的意思,说的是每个样本类别都可以用它最接近的k个邻居的类别来代表。 就比如:判断一个人的人品好坏,只需要观察与他来往最密切的几个人的人品好坏就可以得出,即“近朱者赤,近墨者黑"。
原理
以所有已知类别的样本作为参照,计算未知样本与所有已知样本的距离,从中选取与未知样本距离最近的K个已知样本,根据少数服从多数的投票法则(majority-voting),将未知样本与K个最邻近样本中所属类别占比较多的归为一类。
如下图所示,如何判断绿色圆应该属于哪一类,是属于红色三角形还是属于蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被判定为属于红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆将被判定为属于蓝色四方形类。
1.2 KNN算法的一般流程
1、收集数据:可以使用任何方法。
2、准备数据:距离计算所需要的数值,最好是结构化的数据格式。
3、分析数据:可以使用任何方法。
4、测试算法:计算错误率。
5、使用算法:
5.1. 计算预测数据与训练数据之间的距离
5.2. 将距离进行递增排序
5.3. 选择距离最小的前K个数据
5.4. 确定前K个数据的类别,及其出现频率
5.5. 返回前K个数据中频率最高的类别(预测结果)
1.3 KNN算法的距离计算
KNN算法中对于距离的计算有好几种度量方式,比如欧式距离、曼哈顿距离、切比雪夫距离等等,最常用的就是欧式距离。
欧式距离计算公式:
1.4 KNN算法的K值选择
K值的大小对算法的影响:
K值太大,会导致预测标签比较稳定,可能过平滑,容易欠拟合。
K值太小,会导致预测的标签比较容易受到样本的影响,容易过拟合。
所以对于K值的选取,我们通常使用交叉验证来验证,交叉验证:将样本数据按照一定比例,拆分出训练用的数据和验证用的数据,比如6:4拆分出部分训练数据和验证数据,从选取一个较小的K值开始,不断增加K的值,然后计算验证集合的方差,最终找到一个比较合适的K值。
二.算法实现
主要有以下三个步骤:
算距离:给定待分类样本,计算它与已分类样本中的每个样本的距离;
找邻居:圈定与待分类样本距离最近的K个已分类样本,作为待分类样本的近邻;
做分类:根据这K个近邻中的大部分样本所属的类别来决定待分类样本该属于哪个分类
import math
import csv
import operator
import random
import numpy as np
from sklearn.datasets import make_blobs
#Python version 3.6.5
# 生成样本数据集 samples(样本数量) features(特征向量的维度) centers(类别个数)
def createDataSet(samples=100, features=2, centers=2):
return make_blobs(n_samples=samples, n_features=features, centers=centers, cluster_std=1.0, random_state=8)
# 加载鸢尾花卉数据集 filename(数据集文件存放路径)
def loadIrisDataset(filename):
with open(filename, 'rt') as csvfile:
lines = csv.reader(csvfile)
dataset = list(lines)
for x in range(len(dataset)):
for y in range(4):
dataset[x][y] = float(dataset[x][y])
return dataset
# 拆分数据集 dataset(要拆分的数据集) split(训练集所占比例) trainingSet(训练集) testSet(测试集)
def splitDataSet(dataSet, split, trainingSet=[], testSet=[]):
for x in range(len(dataSet)):
if random.random() <= split:
trainingSet.append(dataSet[x])
else:
testSet.append(dataSet[x])
# 计算欧氏距离
def euclideanDistance(instance1, instance2, length):
distance = 0
for x in range(length):
distance += pow((instance1[x] - instance2[x]), 2)
return math.sqrt(distance)
# 选取距离最近的K个实例
def getNeighbors(trainingSet, testInstance, k):
distances = []
length = len(testInstance) - 1
for x in range(len(trainingSet)):
dist = euclideanDistance(testInstance, trainingSet[x], length)
distances.append((trainingSet[x], dist))
distances.sort(key=operator.itemgetter(1))
neighbors = []
for x in range(k):
neighbors.append(distances[x][0])
return neighbors
# 获取距离最近的K个实例中占比例较大的分类
def getResponse(neighbors):
classVotes = {}
for x in range(len(neighbors)):
response = neighbors[x][-1]
if response in classVotes:
classVotes[response] += 1
else:
classVotes[response] = 1
sortedVotes = sorted(classVotes.items(), key=operator.itemgetter(1), reverse=True)
return sortedVotes[0][0]
# 计算准确率
def getAccuracy(testSet, predictions):
correct = 0
for x in range(len(testSet)):
if testSet[x][-1] == predictions[x]:
correct += 1
return (correct / float(len(testSet))) * 100.0
def main():
# 使用自定义创建的数据集进行分类
# x,y = createDataSet(features=2)
# dataSet= np.c_[x,y]
# 使用鸢尾花卉数据集进行分类
dataSet = loadIrisDataset(r'C:\DevTolls\eclipse-pureh2b\python\DeepLearning\KNN\iris_dataset.txt')
print(dataSet)
trainingSet = []
testSet = []
splitDataSet(dataSet, 0.75, trainingSet, testSet)
print('Train set:' + repr(len(trainingSet)))
print('Test set:' + repr(len(testSet)))
predictions = []
k = 7
for x in range(len(testSet)):
neighbors = getNeighbors(trainingSet, testSet[x], k)
result = getResponse(neighbors)
predictions.append(result)
print('>predicted=' + repr(result) + ',actual=' + repr(testSet[x][-1]))
accuracy = getAccuracy(testSet, predictions)
print('Accuracy: ' + repr(accuracy) + '%')
main()
三、总结
3.1KNN算法的优点
1. 简单易实现:KNN算法的原理简单,易于理解和实现,无需训练过程
2.适用于多分类问题:KNN算法可以应用于多分类问题,并且对样本分布的假设较少。
3.对异常值不敏感:KNN算法对异常值不敏感,因为它是通过距离计算来确定最近邻样本,即使某个样本是异常值,也不会对整体结果产生很大影响
3.2KNN算法的缺点
1.计算复杂度高:KNN算法需要计算测试样本与所有训练样本之间的距离,计算复杂度较高,尤其是在大规模数据集上。
2.高度数据相关:KNN算法依赖于特征空间中的距离度量,如果特征空间中的距离度量不合理或者特征权重不准确,可能会导致预测性能下降。
3.决策边界不规则:KNN算法的决策边界通常是不规则的,因为它只考虑了局部样本的信息,而没有对全局进行建模。
4.对内存要求较高,因为该算法存储了所有训练数据。
5.预测阶段可能很慢:如果数据样本过多的话,需要全部遍历就会耗时间。