KNN学习笔记

'''

已知4点,分为两类,现新插入一点。用KNN来进行分类预测。

'''

 

from numpy import *

 

# 给出训练数据以及对应的类别

def createDataSet():

    group = array([[1.0, 2.0], [1.2, 0.1], [0.1, 1.4], [0.3, 3.5]])

    labels = ['A', 'A', 'B', 'B']

    return group, labels

 

# 计算欧氏距离

def euclideanMetric(input, dataSet):

    # print(dataSet)

    '''

    [[1.  2. ]

    [1.2 0.1]

    [0.1 1.4]

    [0.3 3.5]]

    '''

    dataSize = dataSet.shape[0]                # shape[0]就是读取矩阵第一维度的长度

    # print(dataSize)               # 4

    # 计算欧氏距离

    '''

    tile的作用是将矩阵进行复制。tile中有两个参数(A,B)

    A代表将行复制的次数,B代表将列复制的次数

    '''

    diff = tile(input, (dataSize, 1)) - dataSet

    # print(diff)

    '''

     [[ 0.1 -1.7]

      [-0.1  0.2]

      [ 1.  -1.1]

      [ 0.8 -3.2]]

    '''

    sqdiff = diff ** 2

    # print("sqdiff的输出结果为:", sqdiff)

    '''sqdiff的输出结果为:

    [[1.000e-02 2.890e+00]

     [1.000e-02 4.000e-02]

     [1.000e+00 1.210e+00]

     [6.400e-01 1.024e+01]]

    '''

    squareDist = sum(sqdiff, axis=1)     # axis=1:行向量分别相加,从而得到新的一个行向量

    dist = squareDist ** 0.5

    # 对距离进行排序

    sortedDistIndex = argsort(dist)      # argsort()根据元素的值从大到小对元素进行排序,返回下标

    # print(sortedDistIndex)                 # [1 2 0 3]

    return sortedDistIndex

 

# 通过KNN进行分类

def classify(input, dataSet, label, k):

    # 计算欧氏距离

    sortedDistIndex = euclideanMetric(input, dataSet)

    # 对选取的k个样本所属的类别个数进行统计

    classCount = {}

    for i in range(k):

        voteLabel = label[sortedDistIndex[i]]

        # print(voteLabel)

        '''输出分别为:

        A

        B

        A

        '''

        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1

        # print(classCount.items())

        '''输出分别为:

        dict_items([('A', 1)])

        dict_items([('A', 1), ('B', 1)])

        dict_items([('A', 2), ('B', 1)])

        '''

    # 选取出现的类别次数最多的类别

    maxCount = 0

    for key, value in classCount.items():

        if value > maxCount:

            maxCount = value

            classes = key

    return classes

 

if __name__ == "__main__":

    dataSet, labels = createDataSet()

    input = array([1.1, 0.3])

    K = 3

    output = classify(input, dataSet, labels, K)

    print("测试数据为:", input, "分类结果为:", output)

 


输出结果为:

测试数据为: [1.1 0.3] 分类结果为: A

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值