kNN（k近邻）算法是一种分类算法：对于带标签的训练集，计算新样本与每个训练样本之间的距离，选出距离最近的k个样本；这k个样本中多数属于哪一类，新样本就被归为哪一类。
1. 自定义类实现kNN算法
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import datasets
准备数据
# Load the iris data set and hold out 10% of it as a test set;
# the fixed random_state makes the split reproducible.
iris = datasets.load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.1, random_state=100
)
定义kNN类
class kNN():
    '''
    Classify one sample by majority vote among its k nearest training samples.

    k:       number of nearest neighbours that vote; must be >= 1
    X_train: training samples, one sample per entry (row)
    y_train: label of each training sample, same length as X_train
    x:       the new sample to classify (same shape as one training sample)
    '''

    def __init__(self, k, X_train, y_train, x):
        self.k = k
        self.X = X_train
        self.y = y_train
        self.x = x

    def distance(self):
        '''
        Return a list of Euclidean distances from self.x to every training
        sample, in training-set order.
        '''
        # Work in float so unsigned / small integer inputs cannot wrap around
        # when they are subtracted or squared.
        X = np.asarray(self.X, dtype=float)
        if X.ndim == 1:
            # A flat sequence means one scalar feature per training sample.
            X = X[:, np.newaxis]
        diffs = X - np.asarray(self.x, dtype=float)
        # Sum squared differences over every axis except the sample axis, so
        # multi-dimensional samples are handled as well as feature vectors.
        feature_axes = tuple(range(1, diffs.ndim))
        return list(np.sqrt(np.sum(diffs ** 2, axis=feature_axes)))

    def sort(self):
        '''
        Return the labels of the k training samples closest to self.x,
        ordered from nearest to farthest.

        Raises ValueError if k < 1.
        '''
        if self.k < 1:
            raise ValueError(f'k must be >= 1, got {self.k}')
        # Stable sort: equal distances keep training-set order, so the
        # selected neighbours are deterministic.
        nearest = np.argsort(self.distance(), kind='stable')[:self.k]
        # asarray lets y_train be a plain list as well as an ndarray.
        return np.asarray(self.y)[nearest]

    def knn_classify(self):
        '''
        Return the most common label among the k nearest neighbours.

        Counter.most_common orders equal counts by first occurrence, and the
        labels arrive nearest-first, so a tie goes to the label whose closest
        member is nearer to self.x.

        Raises ValueError if k < 1 or there are no training samples.
        '''
        from collections import Counter
        votes = Counter(self.sort())
        if not votes:
            raise ValueError('no training samples to vote')
        return votes.most_common(1)[0][0]
实例化模型并对测试集进行预测
# Classify every held-out sample with the hand-written kNN (k = 6) and print
# the predicted label next to the true label, tab-separated.
print('预测值\t实际值')
for sample, label in zip(X_test, y_test):
    classifier = kNN(6, X_train, y_train, sample)
    print(f'{classifier.knn_classify()}\t{label}')
预测值 实际值
2 2
0 0
2 2
0 0
2 2
2 2
0 0
0 0
2 2
0 0
0 0
2 2
0 0
0 0
2 2
2. scikit-learn中的kNN算法
准备数据
from sklearn import datasets
# Reload the iris features and labels so this section stands on its own.
iris = datasets.load_iris()
X, y = iris.data, iris.target
划分训练集和测试集
from sklearn.model_selection import train_test_split
# Hold out 20% of the samples for testing; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=100,
)
数据标准化（StandardScaler：将各特征变换为零均值、单位方差）
from sklearn.preprocessing import StandardScaler
# Learn the scaling statistics from the training set only, then apply the
# same transform to both splits so no test information leaks into training.
standardscaler = StandardScaler()
X_train_s = standardscaler.fit(X_train).transform(X_train)
X_test_s = standardscaler.transform(X_test)
预测
from sklearn.neighbors import KNeighborsClassifier
# Fit scikit-learn's kNN (6 neighbours) on the standardized training data,
# predict the test split and report predictions, ground truth and accuracy.
kNN_classify = KNeighborsClassifier(n_neighbors=6).fit(X_train_s, y_train)
yhat = kNN_classify.predict(X_test_s)
accuracy = kNN_classify.score(X_test_s, y_test)
print(f'预测值:{yhat}')
print(f'实际值:{y_test}')
print(f'准确率:{accuracy}')
预测值:[2 0 2 0 1 2 0 0 2 0 0 2 0 0 2 1 1 1 2 2 2 0 2 0 1 2 1 0 1 2]
实际值:[2 0 2 0 2 2 0 0 2 0 0 2 0 0 2 1 1 1 2 2 2 0 2 0 1 2 1 0 1 2]
准确率:0.9666666666666667