用tensorflow实现最近邻算法,对代码进行标注解释。
'''
最邻近算法
'''
from __future__ import print_function
import numpy as np
import tensorflow as tf
#导入数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data",one_hot=True)
#限制数据集数量
Xtr,Ytr = mnist.train.next_batch(5000)#训练集5000
Xte,Yte = mnist.test.next_batch(200)
#图形输入
xtr = tf.placeholder("float",[None,784])
'''
xtr不是一个特定的值,而是一个占位符placeholder,我们在TensorFlow运行计算时输入这个值。
我们希望能够输入任意数量的MNIST图像,每一张图展平成784维的向量。
我们用2维的浮点数张量来表示这些图,这个张量的形状是[None,784 ]。
(这里的None表示此张量的第一个维度可以是任何长度的。)
'''
xte = tf.placeholder("float",[784])
#用L1距离进行最近邻计算
distance = tf.reduce_sum(tf.abs(tf.add(xtr,tf.negative(xte))),reduction_indices=1)
#预测,获取最小的索引
pred = tf.arg_min(distance,0)
accuracy = 0.
#初始化变量(配置默认值)
init = tf.global_variables_initializer()
#开始训练
with tf.Session() as sess:
#开始初始化
sess.run(init)
#循环测试数据
for i in range(len(Xte)):
#得到最近邻
nn_index = sess.run(pred,feed_dict={xtr:Xtr,xte:Xte[i,:]})
#获取最近邻居类别标签并将其与其真实标签进行比较
print("Test",i,"Prediction:",np.argmax(Ytr[nn_index]),"True Class:",np.argmax(Yte[i]))
#计算准确度
if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):
accuracy += 1./len(Xte)
print("Done!")
print("Accuracy:",accuracy)