TensorFlow----最近算法nearest_neighbor

  1. import numpy as np  
  2. import tensorflow as tf  
  3.   
  4. # Import MINST data  
  5. import input_data  
  6. mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)  
  7. #<a target="_blank" href="https://github.com/tensorflow/tensorflow/blob/r0.7/tensorflow/examples/tutorials/mnist/input_data.py">input_data.py</a>  Dataset downloading Loading the entire dataset into numpy array:  
  8. #这里主要是导入数据,数据通过input_data.py已经下载到/tmp/data/目录之下了,这里下载数据的时候,需要提前用浏览器尝试是否可以打开  
  9. #http://yann.lecun.com/exdb/mnist/,如果打不开,下载数据阶段会报错。而且一旦数据下载中断,需要将之前下载的未完成的数据清空,重新  
  10. #进行下载,否则会出现CRC Check错误。read_data_sets是input_data.py里面的一个函数,主要是将数据解压之后,放到对应的位置。  
  11. # In this example, we limit mnist data  
  12. Xtr, Ytr = mnist.train.next_batch(5000#5000 for training (nn candidates)  
  13. Xte, Yte = mnist.test.next_batch(200#200 for testing  
  14. #mnist.train.next_batch,其中train和next_batch都是在input_data.py里定义好的数据项和函数。此处主要是取得一定数量的数据。  
  15. #<code>next_batch</code> 函数可以遍历整个数据集,只返回所需的样本数据集的一部分(为了节省内存和避免加载整个数据集)。  
  16.   
  17. # Reshape images to 1D  
  18. Xtr = np.reshape(Xtr, newshape=(-128*28))  
  19. Xte = np.reshape(Xte, newshape=(-128*28))  
  20. #将二维的图像数据一维化,利于后面的相加操作。  
  21. # tf Graph Input  
  22. xtr = tf.placeholder("float", [None784])  
  23. xte = tf.placeholder("float", [784])  
  24. #设立两个空的类型,并没有给具体的数据。这也是为了基于这两个类型,去实现部分的graph。  
  25.   
  26. # Nearest Neighbor calculation using L1 Distance  
  27. # Calculate L1 Distance  
  28. distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.neg(xte))), reduction_indices=1)  
  29. # Predict: Get min distance index (Nearest neighbor)  
  30. pred = tf.arg_min(distance, 0)  
  31. #最近邻居算法,算最近的距离的邻居,并且获取该邻居的下标,这里只是基于空的类型,实现的graph,并未进行真实的计算。  
  32. accuracy = 0.  
  33. # Initializing the variables  
  34. init = tf.initialize_all_variables()  
  35. #初始化所有的变量和未分配数值的占位符,这个过程是所有程序中必须做的,否则可能会读出随机数值。  
  36. # Launch the graph  
  37. with tf.Session() as sess:  
  38.     sess.run(init)  
  39.   
  40.     # loop over test data  
  41.     for i in range(len(Xte)):  
  42.         # Get nearest neighbor  
  43.         nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i,:]})  
  44.         # Get nearest neighbor class label and compare it to its true label  
  45.         print "Test", i, "Prediction:", np.argmax(Ytr[nn_index]), "True Class:", np.argmax(Yte[i])  
  46.         # Calculate accuracy  
  47.         if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):  
  48.             accuracy += 1./len(Xte)  
  49.     print "Done!"  
  50.     print "Accuracy:", accuracy  
  51. #for循环迭代计算每一个测试数据的预测值,并且和真正的值进行对比,并计算精确度。该算法比较经典的是不需要提前训练,直接在测试阶段进行识别 




源码地址:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/nearest_neighbor.py

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值