# Classify the MNIST handwritten-digit dataset with a nearest-neighbor classifier
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST. one_hot=True makes every label a one-hot vector, so the class
# index is recovered with np.argmax further down.
mnist = input_data.read_data_sets("/root/data/", one_hot=True)

# Xtr/Ytr: 5000 training samples that act as the reference set.
# Xte/Yte: 200 test samples to classify.
Xtr, Ytr = mnist.train.next_batch(5000)
Xte, Yte = mnist.test.next_batch(200)

# Graph inputs: the whole reference set (any number of rows, 784 values each)
# and a single flattened test image (784 values).
xtr = tf.placeholder("float", [None, 784])
xte = tf.placeholder("float", [784])

# L1 (Manhattan) distance between the test image and every reference image:
# sum of absolute differences along axis 1. The subtraction is written with
# the `-` operator rather than tf.add(..., tf.neg(...)) so it does not depend
# on an op name that later TensorFlow releases renamed.
distance = tf.reduce_sum(tf.abs(xtr - xte), reduction_indices=1)

# Index of the reference image with the smallest distance (nearest neighbor).
pred = tf.argmin(distance, 0)

accuracy = 0
init = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init)
    for i in range(len(Xte)):
        # tf.Session.run() returns the result of evaluating `pred`. If the
        # fetch is a single element a single value is returned; if it is a
        # list, a list of values is returned; if it is a dict, a dict with
        # the same keys is returned.
        nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]})

        # Convert the one-hot labels back to class indices once per sample.
        predicted = np.argmax(Ytr[nn_index])
        actual = np.argmax(Yte[i])

        # Single pre-formatted string: prints the same text under the
        # Python 2 print statement and the Python 3 print function.
        print("Test %s Prediction: %s True Class: %s" % (i, predicted, actual))

        if predicted == actual:
            # Each correct prediction contributes an equal share, so the
            # running total ends up as the fraction classified correctly.
            accuracy += 1. / len(Xte)
    print("Done")
    print("Accuracy: %s" % accuracy)
# Image embedded in the original blog post (content not visible here):
# https://img-blog.csdn.net/20161231204250023?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvd3p3MTIzMTU=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/Center