数据集来自MNIST。
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import input_data
import numpy as np
import tensorflow as tf
if __name__ == '__main__':
mnist = input_data.read_data_sets('/tmp/data', one_hot=True)
# 不导入所有数据
Xtr, Ytr = mnist.train.next_batch(5000)
Xte, Yte = mnist.test.next_batch(200)
print(Xtr.shape) # (5000, 784)
print(Ytr.shape) # (5000, 10)
print(Xte.shape) # (200, 784)
print(Yte.shape) # (200, 10)
# 下面定义数据流图
xtr = tf.placeholder("float", [None, 784]) # None代表第一维长度任意
xte = tf.placeholder("float", [784])
# 计算L1Distance,定义一个施加的数学操作,从第二维这个维度进行求和
distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)
# 求出最小元素的下标(距离最近)
pred = tf.arg_min(distance, dimension=0)
accuracy = 0.
# 变量初始化
init = tf.global_variables_initializer()
# 下面运行数据流图
with tf.Session() as sess:
# 运行初始化
sess.run(init)
# 遍历测试数据
for i in range(len(Xte)):
# 得到最近邻居在训练集中的下标
nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]})
print("Test ", i, " Prediction: ", np.argmax(Ytr[nn_index]), " True Class: ", np.argmax(Yte[i]))
if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):
accuracy += 1. / len(Xte)
print("The accuracy is: ", accuracy)
'''
output:
Epoch: 0050 cost= 0.11877864 W= 0.36398676 b= -0.021470055
Epoch: 0100 cost= 0.11394625 W= 0.3571833 b= 0.02747381
Epoch: 0150 cost= 0.10967217 W= 0.3507844 b= 0.07350685
Epoch: 0200 cost= 0.10589187 W= 0.34476617 b= 0.11680212
Epoch: 0250 cost= 0.10254839 W= 0.33910576 b= 0.15752237
Epoch: 0300 cost= 0.09959126 W= 0.33378205 b= 0.19582067
Epoch: 0350 cost= 0.09697587 W= 0.32877496 b= 0.23184128
Epoch: 0400 cost= 0.09466273 W= 0.32406574 b= 0.26571926
Epoch: 0450 cost= 0.092616975 W= 0.31963646 b= 0.29758275
Epoch: 0500 cost= 0.09080763 W= 0.3154707 b= 0.32755125
Epoch: 0550 cost= 0.08920752 W= 0.3115527 b= 0.35573712
Epoch: 0600 cost= 0.08779238 W= 0.30786765 b= 0.382247
Epoch: 0650 cost= 0.08654087 W= 0.30440176 b= 0.40718022
Epoch: 0700 cost= 0.08543408 W= 0.30114213 b= 0.43063
Epoch: 0750 cost= 0.08445534 W= 0.29807636 b= 0.45268512
Epoch: 0800 cost= 0.08358977 W= 0.29519278 b= 0.4734293
Epoch: 0850 cost= 0.082824335 W= 0.29248077 b= 0.4929396
Epoch: 0900 cost= 0.08214753 W= 0.2899301 b= 0.5112887
Epoch: 0950 cost= 0.081548996 W= 0.28753105 b= 0.52854717
Epoch: 1000 cost= 0.08101973 W= 0.2852747 b= 0.5447794
Optimization finished!
Training cost= 0.00049043633 W= 0.2852747 b= 0.5447794
Testing:
Test cost = 0.0765191
'''