手撸KNN算法(python语言)

KNN算法手写代码实现,距离度量使用的欧氏距离,采用了矩阵运算,运行速度较快。
闲言碎语不要讲,上代码。

# 导入三方库
import pandas as pd
from sklearn.preprocessing import LabelEncoder
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from scipy import stats

# 加载数据
df = pd.read_csv('./data/mushrooms.csv')

X = df.drop(['class'], axis=1)
Y = df['class']
# 数据预处理
X = pd.get_dummies(X,prefix_sep='_')
# 类型转换为int型
for col in X.columns:
    X[col] = X[col].astype(int)
# 目标特征编码
lb = LabelEncoder()
Y = lb.fit_transform(Y)

#定义KNN类
class KNN:
    def __init__(self, k=3):
        self.k = k

    def fit(self, X, y):
        X = np.asarray(X, dtype='float64')
        y = np.asarray(y, dtype='float64')
        self.X_train = X
        self.y_train = y

    def predict(self, X):
        X = np.asarray(X, dtype='float64')
        y_pred = self._predict(X)
        return np.array(y_pred)
        
    def _predict(self, X):
    	# 计算欧氏距离
        # XX 为一列,维度等于X的行数
        XX = np.einsum("ij,ij->i", X, X)[:, np.newaxis]
        # YY 为一行,维度等于Y的行数
        YY = np.einsum("ij,ij->i", self.X_train, self.X_train)[np.newaxis, :]
        distance = -2 * (X @ self.X_train.T)
        distance += XX
        distance += YY
        distance = np.sqrt(distance)
        # 获取距离最近的k个样本标签
        k_indices = np.argpartition(distance, self.k - 1, axis=1)[:, :self.k] # 此处使用 argpartition 快排出前k个最近距离,执行速度比 argsort 对所有数据进行排序要快一倍。
        k_nearest_labels = self.y_train[k_indices]
        mode, _ = stats.mode(k_nearest_labels, axis=1)
        mode = np.asarray(mode.ravel(), dtype=np.intp)
        return mode
# 划分训练集与测试集
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=2023)
knn = KNN(k=5)
knn.fit(X_train, y_train)
pred = knn.predict(X_test)
print(confusion_matrix(y_test, pred))
print(classification_report(y_test, pred))

执行输出:

[[841   0]
 [  0 784]]
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       841
           1       1.00      1.00      1.00       784

    accuracy                           1.00      1625
   macro avg       1.00      1.00      1.00      1625
weighted avg       1.00      1.00      1.00      1625
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值