tensorflow 使用nearest最邻近算法 分类mnist数据库

#using nearest to classify mnist handwrite dataset

import numpy as np
import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/root/data/",one_hot = True)

Xtr,Ytr = mnist.train.next_batch(5000)
Xte,Yte = mnist.test.next_batch(200)

xtr = tf.placeholder("float",[None,784])
xte = tf.placeholder("float",[784])

distance = tf.reduce_sum(tf.abs(tf.add(xtr,tf.neg(xte))),reduction_indices = 1)
pred = tf.arg_min(distance, 0)

accuracy = 0
 
init = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init)
 
    for i in range(len(Xte)):
        //tf.Session.run()函数返回值为pred的执行结果。如果pred是一个元素就返回一个值;若pred是一个list,则返回list的值,若pred是一个字典类型,则返回和pred同keys的字典。
       nn_index = sess.run(pred, feed_dict = {xtr:Xtr, xte:Xte[i,:]})
       print "Test", i, "Prediction:", np.argmax(Ytr[nn_index]), "True Class:", np.argmax(Yte[i])

       if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):
          accuracy += 1./len(Xte)
   
    print "Done"
    print "Accuracy:", accuracy
<img data-cke-saved-src="https://img-blog.csdn.net/20161231204250023?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvd3p3MTIzMTU=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/Center" src="https://img-blog.csdn.net/20161231204250023?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvd3p3MTIzMTU=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/Center" alt="" />

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值