机器学习之knn-邻近算法

一,简介

​ knn邻近算法,或者说K最近邻(KNN,K-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是K个最近的邻居的意思,说的是每个样本都可以用它最接近的K个邻近值来代表。近邻算法就是将数据集合中每一个记录进行分类的方法 。

二,算法原理

​ 通过欧式距离公式计算两个向量点 xA、xB之间的距离:
d = ( x A 0 − x B 0 ) 2 + ( x A 1 − x B 1 ) 2 d=\sqrt{(xA_0-xB_0)^2+(xA_1-xB_1)^2} d=(xA0xB0)2+(xA1xB1)2
例如,点(0,0)与(1,2)之间的距离计算为
( 1 − 0 ) 2 + ( 2 − 0 ) 2 \sqrt{(1-0)^2+(2-0)^2} (10)2+(20)2
如果存在多个特征值则以此类推。

​ KNN算法就是利用测试样本数据与训练集中的每个数据计算距离,然后通过计算得出的距离排序,得出距离最近的n个数据,然后统计n个中各个标签的数量,数量最多的标签就是测试样本数据的结果(标签),因此n的取值很重要。

三,实例

​ 本次算法利用手写体识别作为样例来演示本算法。废话不多说,贴代码:

# 导入numpy库并且设置缩写为np
import numpy as np
# 导入os库
import os


# 定义获取单个文件数据的函数
def get_data(name):
    # 打开文件
    f1 = open(name, 'r')
    # 设置一个一维array来存储数据
    data = np.zeros((1, 1024))
    # 遍历文件来将数据写入array
    for i in range(32):
        # 读取每一行数据
        now_data = f1.readline()
        # 读取每行的每个数据
        for j in range(32):
            # 将数据写入array
            data[0][32 * i + j] = now_data[j]
    # 返回每一个文本的array
    return data

# 存下训练集的文件名
data_dir = os.listdir("trainingDigits")
# 存下训练集的数据
train_datas = []
# 遍历每个文件,存下训练数据
for i in data_dir:
    # 拼接文件名
    now_name = "trainingDigits/" + str(i)
    # 分隔标签名
    name_count = i.split("_")[0]
    # 将数据和标签存入array
    train_datas.append([get_data(now_name), name_count])

# 获取测试集的文件名
test_data_dir = os.listdir("testDigits")
# 用于存下测试集的数据
test_datas = []
# 遍历每个文件,存入测试数据
for j in test_data_dir:
    # 拼接文件名
    now_name = "testDigits/" + str(j)
    # 存入数据和文件名
    test_datas.append([get_data(now_name), j])


# 用于计算距离
def get_juli(test_data, train_data):
    # 计算距离
    res = sum((test_data[0] - train_data[0]) ** 2) ** 0.5
    # 返回结果
    return res


# 用于判断给测试数据打上标签
def run_test_data(test_data, k):
    # 存下测试数据和训练集每个数据的距离
    end_data = []
    # 遍历测试数据
    for i in train_datas:
        # 获取训练数据的单个数据集
        train_data = i[0]
        # 获取数据集的标签
        label = i[1]
        # 计算距离
        number = get_juli(test_data, train_data)
        # 距离和标签一起添加到结果列表中
        end_data.append((number, label))
    # 数据排序
    end_data.sort()
    # 统计标签
    label_count = {}
    # 遍历结果数据
    for i in range(k):
        # 获取标签
        label = end_data[i][1]
        # 计数标签
        if label in label_count:
            # 标签存在统计数+1
            label_count[label] += 1
        else:
            # 标签不存在,统计数为1
            label_count[label] = 1
    # 排序结果
    res_labels = sorted(label_count.items(), key=lambda x: x[1], reverse=True)
    # 返回最大的可能性(也就是测试集的标签)
    return res_labels[0][0]


# 统计总数
all_count = len(test_datas)
# 统计错误数
error_count = 0
# 统计错误文件
error_ls = []
# 打开一个文件,用于写下结果
res_txt = open("res.txt", 'w', encoding='utf-8')
# 遍历测试集
for test in test_datas:
    # 获取测试的数据
    now_test_data = test[0]
    # 获取测试集的名字
    now_test_name = test[1]
    # 获取测试集的真实标签
    correct_number = now_test_name.split("_")[0]
    # 训练测试集
    res = run_test_data(now_test_data, 10)
    # 用于判断是否正确
    if res != correct_number:
        # 错误,则统计数+1
        error_count += 1
        # 统计错误的测试集
        error_ls.append(now_test_name)
    # 写入文件
    res_txt.write("{}的结果是{}".format(now_test_name, res) + '\n')
    # 打印结果
    print("{}的结果是{}".format(now_test_name, res))
# 打印最终结果
print(f"错误的有{error_count}个, 错误率是{error_count/all_count}, 是{error_ls}")
# 关闭文件
res_txt.close()

结果展示训练结果
写入文本

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值