K近邻算法的Python代码实现(CS231课程)

K近邻(KNN)是一种基本的机器学习方法,尽管在分类中使用较少,但其理念常在生活中体现。该算法基于三个要素:距离度量(如欧式距离)、k值选择(通过交叉验证确定最佳值)和多数表决的分类决策规则。本文将探讨如何在Python中实现KNN,并提供三种计算距离的方法,以提高效率。
摘要由CSDN通过智能技术生成

算法描述

K近邻法(k-nearest neighbors,KNN)是一种很基本的机器学习方法,实际做分类时已经用得不多了,但实际生活中还是会不自觉的使用这种算法。比如想了解一个人的人品,可以看他跟平时跟哪些人走得比较近,一般他跟他的那些朋友们的人品不会相差太远。在用KNN做分类预测时,也是同理,即训练集里和预测的样本特征最近的K个样本,预测为里面有最多类别数的类别,类似于多数投票表决。

K近邻法的三要素

1.距离度量
距离度量可以选择曼哈顿距离,欧式距离等,在大多数的情况下使用欧式距离就能满足需求,下面的算法实现中使用的是欧式距离。
2.k值的选择
k值是一个很重要的参数,不同的k值多结果影响很大,可以把值设置成一个训练参数,通常采用交叉验证法来选取最优的k值。
3.分类决策规则
k近邻法中的分类决策规则一般使用的是多数表决的办法,预测类别为k个邻近的训练集中的多数类。

在Python代码中,循环的使用会使代码效率降低,所以,要尽量减少循环的使用。以下代码中展示了三种计算距离的方式,感兴趣的读者可以自行对比这三者的运行效率。

import numpy as np


class KNearestNeighbor(object):
    """ a kNN classifier with L2 distance """

    def __init__(self):
        self.x_train = None
        self.y_train = None

    def train(self, x, y):
        """
        Train the classifier. For k-nearest neighbors this is just
        memorizing the training data.

        :param x: A numpy array of shape (num_train, D) containing the training data
                  consisting of num_train samples each of dimension D.
        :param y: A numpy array of shape (N,) containing the training labels, where
                  y[i] is the label for X[i].
        :return: None
        """
		#训练直接接收分类的训练集即可
        self.x_train = x
        self.y_train = y
    
    def predict(self, x, k=1, num_loops=0):
        """
        Predict labels for test data using this classifier.

        :param x: A numpy array of shape (num_test, D) containing test data consisting
                  of num_test samples each of dimension D.
        :param k: The number of nearest neighbors that vote for the predicted labels.
        :param num_loops: Determines which implementation to use to compute distances
                          between training points and testing points.
        :return y: A numpy array of shape (num_test,) containing predicted labels for the
                   test data, where y[i] is the predicted label for the test point X[i].
        """
        # 暴力搜索 复杂度高
        if num_loops == 0:
            dists = self.compute_distances_no_loops(x)
        elif num_loops == 1:
            dists 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值