Python 实现 KNN 分类算法

本文将详细讲述 KNN 算法及其 python 实现

1. KNN

KNN(K-Nearest Neighbour)即 K最近邻,是分类算法中最简单的算法之一。KNN 算法的核心思想是 如果一个样本在特征空间中的 k 个最相邻的样本中的大多数属于某一个类别,则将该样本归为该类别

1.1 KNN 分类算法步骤

有 N 个已知分类结果的样本点,对新纪录 r 使用 KNN 将其分类

  • 1.确定 k 值,确定计算距离的公式,如常用欧氏距离 d ( x , y ) = ∑ i = 1 n ( x i − y i ) 2 d(x,y)=\sqrt{\displaystyle \sum^n_{i = 1}{{(x_i-y_i)}^2}} d(x,y)=i=1n(xiyi)2
  • 2.计算 r 和其他样本点之间的距离 d i r d_{ir} dir,其中 i ∈ ( 1 , N ) i\in(1,N) i(1,N)
  • 3.得到与 r 最接近的 k 个样本
  • 4.将 k 个样本中最多归属类别的分类标签赋予新纪录 r,分类结束

1.2 KNN 的优缺点

优点:

  • 原理简单,容易理解,容易实现
  • 重新训练代价较低
  • 时间复杂度、空间复杂度取决于训练集(一般不会太大)

缺点:

  • KNN 属于 lazy-learning 算法(对于每一个新加入的预测点,都要从头开始计算与每个样本点的距离),得到的结果及时性差
  • k 值对结果影响较大
  • 不同类记录相差较大时容易误判
  • 样本点较多时,计算量较大
  • 相对于决策树,结果可解释性不强

2. python 实现

已知分类如图所示(由于是随机产生,所以具体的样本点可能不一样)

其中顺时针依次是第1、2、3类,即红色是第 1 类,蓝色是第 2 类, 灰色是第 3 类

# coding=utf-8

"""
@author: shenke
@project: AITest
@file: knn.py
@date: 2020/2/26
@description: python 实现 KNN(K-最邻近)分类算法
"""

import numpy as np
import matplotlib.pyplot as plt
from math import sqrt


class KNN():

    def __init__(self, k):
        self.k = k

    def generate_points(self, x_scope, y_scope, size):
        """
        产生给定范围内的二维坐标点
        """
        x = np.random.randint(x_scope[0], x_scope[1], size=size)
        y = np.random.randint(y_scope[0], y_scope[1], size=size)
        points = np.dstack((x, y))[0]
        return points

    def generate_data(self, size):
        """
        随机产生三个范围内的数据
        """
        points1 = self.generate_points([0, 8], [12, 20], size)
        labels1 = [1] * size
        points2 = self.generate_points([12, 20], [12, 20], size)
        labels2 = [2] * size
        points3 = self.generate_points([7, 13], [0, 8], size)
        labels3 = [3] * size

        plt.scatter(points1[:size, 0], points1[:size, 1], color='red')
        plt.scatter(points2[:size, 0], points2[:size, 1], color='blue')
        plt.scatter(points3[:size, 0], points3[:size, 1], color='gray')

        data = np.concatenate([points1, points2, points3])
        label = np.concatenate([labels1, labels2, labels3])
        return data, label

    def classify(self, target):
        """
        实现 KNN 分类
        """
        k = self.k
        # 设定每个类别中有 10 个样本点
        data, label = self.generate_data(10)

        # 计算欧氏距离
        distance = [sqrt(np.sum((target - point) ** 2)) for point in data]
        # 返回距离最近的 k 个样本的下标
        k_index = np.argsort(distance)[:k]
        # 返回 k 个样本的标签
        k_labels = [label[item] for item in k_index]
        # 返回 k 个样本中最多归属类别的分类标签
        res = max(k_labels, key=k_labels.count)

        print('该目标点为:第 %d 类' % (res))

        # 展示结果
        # 标出距离最近的 k 个样本点
        plt.scatter([data[index][0] for index in k_index], [data[index][1] for index in k_index], color='', marker='o',
                    edgecolors='green', s=200)
        # 标出目标点
        plt.scatter(target[0], target[1], color='green')
        plt.show()

测试

from algorithm import knn

if __name__ == '__main__':
    # 设定 k 值为 4,预测点坐标为(10,10)
    knn.KNN(4).classify([10, 10])

预测结果

上图中标出了预测点(绿色)并圈出了与预测点距离最近的四个点,其中属于第 3 类的样本点个数最多,故预测该点属于第 3 类

但是由于 k 值对预测结果影响较大,可能对预测结果产生误判。如以下情况,四个点中属于第 1 类和第 3 类的样本点个数一样多,这时就无法准确判断出该点的类别

  • 4
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值