tensorflow-实现knn算法-识别mnist数据集

概述

Mnist数据集被分成两部分:60000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test)。每一张图片包含28像素X28像素的灰度图片。
我们要做的是对测试数据集的每一个数据从训练数据集中找出最临近的类别,进行预测。
那么如何找最临近的(距离),我们通过计算L1距离或L2距离来计算最临近。

实现步骤

  • 1.获取mnist数据
  • 2.计算距离(L1或L2)
  • 3.获取最小距离得索引(计算准确度用)
  • 4.开启会话,初始化变量op
  • 5.循环对每一个测试数据分别查找和训练数据集的最小距离,得出索引。
  • 6.计算准确度(预测值和真实值相等的数据/所有测试数据)

实现代码

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np


def knn_tensorflow():
    """ tensorflow实现knn算法,对mnist数据识别分类
    :return None
    """
    mnist = input_data.read_data_sets("./data/mnist/", one_hot=True)

    # 数据全部取出, 普通pc计算需要半小时左右,如果嫌太慢,可以少取一些数据。
    train_x, train_y = mnist.train.next_batch(60000)
    test_x, test_y = mnist.test.next_batch(10000)

    # 占位符
    train_x_p = tf.placeholder(tf.float32, [None, 784])
    test_x_p = tf.placeholder(tf.float32, [784])

    # L1距离计算:dist = sum(|X1-X2|)
    #dist_l1 = tf.reduce_sum(tf.abs(train_x_p + tf.negative(test_x_p)), reduction_indices=1)

    # L2距离计算:dist = sqrt(sum(|X1-X2|^2))
    dist_l2 = tf.sqrt(tf.reduce_sum(tf.square(tf.abs(train_x_p + tf.negative(test_x_p))), reduction_indices=1))

    # 获得最小距离的索引
    prediction = tf.arg_min(dist_l2, 0)

    # 定义准确率
    accuracy = 0.

    init_op = tf.initialize_all_variables()

    with tf.Session() as sess:
        sess.run(init_op)

        for i in range(len(test_x)):
            # 获取最近邻的值得索引
            nn_index = sess.run(prediction, feed_dict={train_x_p: train_x, test_x_p: test_x[i, :]})
            print("测试集第 %d 条,实际值:%d,预测值:%d" % (i, np.argmax(test_y[i]), np.argmax(train_y[nn_index])))

            # 当预测值==真实值时,计算准确率。
            if np.argmax(test_y[i]) == np.argmax(train_y[nn_index]):
                accuracy += 1. / len(test_x)

        print("准确率:%f " % accuracy)

    return None

if __name__ == '__main__':
    knn_tensorflow()

L1距离输出:
在这里插入图片描述
L2距离输出:
在这里插入图片描述
以上是knn算法的tensorflow实现,但大家可以看出,并没有k值的调整,以上默认k值是1,那假设k值是3或者5的时候会不会更加准确,可以自己尝试下。

注:tf.reduce_sum()方法中有个reduction_indices参数,表示函数的处理纬度,默认为None,即会把input_tensor降到0维,即1个数值。如下图所示(图片来自网络):
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值