TensorFlow_NearestNeighbour实现

数据集来自MNIST

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import input_data

import numpy as np
import tensorflow as tf

if __name__ == '__main__':
    mnist = input_data.read_data_sets('/tmp/data', one_hot=True)

    # 不导入所有数据
    Xtr, Ytr = mnist.train.next_batch(5000)
    Xte, Yte = mnist.test.next_batch(200)

    print(Xtr.shape)  # (5000, 784)
    print(Ytr.shape)  # (5000, 10)
    print(Xte.shape)  # (200, 784)
    print(Yte.shape)  # (200, 10)

    # 下面定义数据流图
    xtr = tf.placeholder("float", [None, 784])  # None代表第一维长度任意
    xte = tf.placeholder("float", [784])

    # 计算L1Distance,定义一个施加的数学操作,从第二维这个维度进行求和
    distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)

    # 求出最小元素的下标(距离最近)
    pred = tf.arg_min(distance, dimension=0)

    accuracy = 0.

    # 变量初始化
    init = tf.global_variables_initializer()

    # 下面运行数据流图
    with tf.Session() as sess:
        # 运行初始化
        sess.run(init)

        # 遍历测试数据
        for i in range(len(Xte)):
            # 得到最近邻居在训练集中的下标
            nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]})
            print("Test ", i, " Prediction: ", np.argmax(Ytr[nn_index]), " True Class: ", np.argmax(Yte[i]))

            if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):
                accuracy += 1. / len(Xte)

    print("The accuracy is: ", accuracy)


'''
output:
Epoch: 0050 cost= 0.11877864 W= 0.36398676 b= -0.021470055
Epoch: 0100 cost= 0.11394625 W= 0.3571833 b= 0.02747381
Epoch: 0150 cost= 0.10967217 W= 0.3507844 b= 0.07350685
Epoch: 0200 cost= 0.10589187 W= 0.34476617 b= 0.11680212
Epoch: 0250 cost= 0.10254839 W= 0.33910576 b= 0.15752237
Epoch: 0300 cost= 0.09959126 W= 0.33378205 b= 0.19582067
Epoch: 0350 cost= 0.09697587 W= 0.32877496 b= 0.23184128
Epoch: 0400 cost= 0.09466273 W= 0.32406574 b= 0.26571926
Epoch: 0450 cost= 0.092616975 W= 0.31963646 b= 0.29758275
Epoch: 0500 cost= 0.09080763 W= 0.3154707 b= 0.32755125
Epoch: 0550 cost= 0.08920752 W= 0.3115527 b= 0.35573712
Epoch: 0600 cost= 0.08779238 W= 0.30786765 b= 0.382247
Epoch: 0650 cost= 0.08654087 W= 0.30440176 b= 0.40718022
Epoch: 0700 cost= 0.08543408 W= 0.30114213 b= 0.43063
Epoch: 0750 cost= 0.08445534 W= 0.29807636 b= 0.45268512
Epoch: 0800 cost= 0.08358977 W= 0.29519278 b= 0.4734293
Epoch: 0850 cost= 0.082824335 W= 0.29248077 b= 0.4929396
Epoch: 0900 cost= 0.08214753 W= 0.2899301 b= 0.5112887
Epoch: 0950 cost= 0.081548996 W= 0.28753105 b= 0.52854717
Epoch: 1000 cost= 0.08101973 W= 0.2852747 b= 0.5447794
Optimization finished!
Training cost= 0.00049043633 W= 0.2852747 b= 0.5447794
Testing:
Test cost = 0.0765191
'''

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值