Basic example: nearest neighbor algorithm
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 20 19:26:25 2017
@author: wu
"""
# 引入模块
from __future__ import print_function
import tensorflow as tf
import numpy as np
#下载mnist数据集
from tensorflow.examples.tutorials.mnist import input_data
# Nearest-neighbor (1-NN) classification of MNIST digits using L1 distance.
#
# Download (if not already cached in MNIST_data/) and load MNIST with
# one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Number of reference (training) images and query (test) images to use.
NUM_TRAIN = 5000
NUM_TEST = 200

# Reference set comes from the training split; queries come from the
# held-out test split. (Drawing queries from mnist.train would evaluate the
# classifier on training-split data instead of unseen test data.)
Xtr, Ytr = mnist.train.next_batch(NUM_TRAIN)
Xte, Yte = mnist.test.next_batch(NUM_TEST)

# Graph inputs: the whole reference set (one 784-pixel row per image) and a
# single flattened 784-pixel query image.
xtr = tf.placeholder("float", [None, 784])
xte = tf.placeholder("float", [784])

# L1 (Manhattan) distance from the query to every reference row; xte is
# broadcast across the rows of xtr, and the sum runs over the pixel axis.
distance = tf.reduce_sum(tf.abs(tf.subtract(xtr, xte)), axis=1)
# Index of the reference image closest to the query.
pred = tf.argmin(distance, 0)

correct = 0
accuracy = 0.
# The graph contains no tf.Variable, so no initializer needs to be run.
with tf.Session() as sess:
    # Classify each query image independently.
    for i, (query, label) in enumerate(zip(Xte, Yte)):
        nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: query})
        # Labels are one-hot, so argmax recovers the digit class.
        predicted = np.argmax(Ytr[nn_index])
        actual = np.argmax(label)
        print("Test", i, "Prediction: ", predicted, "True class: ", actual)
        # Count the query as correct when the nearest neighbor's class
        # matches the query's true class.
        if predicted == actual:
            correct += 1
    print("Done!")
    # Divide once at the end rather than summing float fractions; guard
    # against an empty query set.
    if len(Xte):
        accuracy = correct / float(len(Xte))
    print("Accuracy: ", accuracy)
部分运行结果截图