KNN算法 数字识别

import os
import numpy as np


def data_trans(dir_path):
    file_list = os.listdir(dir_path)
    # print(file_list)
    # 声明一个大数组 存放所有文件的数组和标签
    big_arr = np.zeros((len(file_list), 1025))
    for i, file in enumerate(file_list):
        # 标签
        flag = file[0]
        # print(flag)
        # 拼接文件路径
        file_path = dir_path + '/' + file
        # print(file_path)
        # 读取文件  一维
        file_arr = np.loadtxt(file_path, dtype=str)
        # print(file_arr)
        # 用来存放每个文件中的数组
        arr = np.zeros((32, 32))
        for j, num in enumerate(file_arr):
            arr[j] = list(map(int, num))
        # print(arr)
        # 将arr展平成1*1024
        arr_ravel = arr.ravel()
        # print(arr_ravel)
        big_arr[i, 0:-1] = arr_ravel
        # 最后一列用来存放标签
        big_arr[i, -1] = flag
        # break
    name = dir_path.split('/')[-1]
    # print(big_arr.shape)
    np.savetxt("{}.csv".format(name), big_arr, fmt='%d')


if __name__ == '__main__':
    dir_path1 = "./digits/trainingDigits"
    dir_path2 = "./digits/testDigits"
    data_trans(dir_path1)
    data_trans(dir_path2)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def knn(train_digit, test_digit, k):
    true_num = 1
    # 求测试集的每一行的相似度
    for i in range(test_digit.shape[0]):
        d = np.sqrt(((test_digit[i, :-1]-train_digit[:, :-1])**2).sum(axis=1))
        sort_index = d.argsort()[:k]
        flag = train_digit[sort_index, -1]
        df = pd.DataFrame(flag).mode()
        print("预测值: ", df[0][0])
        print("真实值: ", test_digit[i, -1])
        if df[0][0] == test_digit[i, -1]:
            true_num += 1
    print('准确度为:', true_num/test_digit.shape[0])
    return true_num/test_digit.shape[0]


if __name__ == '__main__':
    train_digit = np.loadtxt("./trainingdigits.csv")
    test_digit = np.loadtxt("./testdigits.csv")
    y = []
    for k in range(5, 15):
        prec = knn(train_digit, test_digit, k)
        y.append(prec)
    x = range(5, 15)

    print(x)
    print(y)
    plt.figure()
    plt.plot(x, y, marker='*', markersize=12)
    plt.xlabel('k')
    plt.ylabel('precision')
    plt.show()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值