基于TensorFlow的最近邻(NN)分类器——以MNIST识别为例

一、最近邻分类理论

可以参考:这里写链接内容
https://wenku.baidu.com/view/d0924523a45177232e60a201.html
https://www.cnblogs.com/bugsheep/p/7879407.html
http://advkwo.blog.163.com/blog/static/2541350720106532518960/

二、TF在CPU上实现NN分类

具体代码如下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data   
'''========load data========'''
#mnist = input_data.read_data_sets('MNIST_data', one_hot=True)  #'MNIST_data'设置请参考前几篇文章

# 获得训练样本个数
#train_nums = mnist.train.num_examples

# 读取所有训练样本和测试样本
#X_train = mnist.train.images   
#X_test = mnist.test.images
#Y_train = mnist.train.labels
#Y_test = mnist.test.labels

# 批量读取部分样本
X_train,Y_train = mnist.train.next_batch(1000)
X_test,Y_test = mnist.test.next_batch(200)

'''========0.定义常量========'''
insize = 784  #input size

'''计算图输入占位符'''
xs = tf.placeholder(tf.float32,[None,insize])
xst  = tf.placeholder(tf.float32,[insize])

'''使用 L1 距离进行最近邻计算'''
# L1:dist = sum(|X1-X2|)  或 L2:dist=sqrt(sum(|X1-X2|^2))
dist = tf.reduce_sum(tf.abs(tf.add(xs,tf.negative(xst))),
                     reduction_indices=1)
#或dist = tf.reduce_sum(tf.abs(tf.subtract(xtrain, xtest))), axis=1)

# 预测: 获得最小距离的索引,然后根据此索引的类标和正确的类标进行比较
index = tf.arg_min(dist,0)

# 初始化所有变量
init = tf.global_variables_initializer()

#定义一个正确率计算器
Accuracy = 0;

# 执行会话
with tf.Session() as sess:
    sess.run(init) 
    # 只能循环地对测试样本进行预测
    for i in range(len(X_test)):  
        #print('Dist=',sess.run(dist,feed_dict={xs:X_train,xst:X_test[i,:]}))
        id = sess.run(index,feed_dict={xs:X_train,xst:X_test[i,:]})
        # 计算预测标签和正确标签用于比较
        Predict_label = np.argmax(Y_train[id])
        True_label = np.argmax(Y_test[i])

        print("Test Sample",i,"Prediction label:",Predict_label,
              "True Class label:",True_label)

        # 计算精确度
        if Predict_label == True_label:
            Accuracy +=1
    print("Accuracy=",Accuracy/len(X_test))

方法二:将其读数据和NN分类单独写出函数,方便后期调用

import tensorflow as tf
def load_mnist_data(filename,isbatch=0,train_nums=1000,test_nums=200):
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets(filename, one_hot=True)
       #2、批量获取样本
    if isbatch==1:
        X_train,Y_train = mnist.train.next_batch(train_nums)
        X_test,Y_test = mnist.test.next_batch(test_nums)
        return X_train,Y_train,X_test,Y_test
    else:
        #1、获取全部样本
        X_train = mnist.train.images   #[1:10]
        X_test = mnist.test.images
        Y_train = mnist.train.labels
        Y_test = mnist.test.labels
        return X_train,Y_train,X_test,Y_test

def NN_Classifier(X_train,Y_train,X_test,Y_test,dims=784,dist_metric='L1'):
    # 计算图输入占位符
    xs = tf.placeholder(tf.float32,[None,dims])
    xst  = tf.placeholder(tf.float32,[dims])  
    # 使用 L1 距离进行最近邻计算
    # L1:dist = sum(|X1-X2|)  或 L2:dist=sqrt(sum(|X1-X2|^2))
    dist = tf.reduce_sum(tf.abs(tf.add(xs,tf.negative(xst))),
                         reduction_indices=1)
    #或dist = tf.reduce_sum(tf.abs(tf.subtract(xtrain, xtest))), axis=1)

    # 预测: 获得最小距离的索引,然后根据此索引的类标和正确的类标进行比较
    index = tf.arg_min(dist,0)

    # 初始化所有变量
    init = tf.global_variables_initializer()    

    #定义一个正确率计算器
    Accuracy = 0

    # 执行会话
    with tf.Session() as sess:
        sess.run(init) 
        # 只能循环地对测试样本进行预测
        for i in range(len(X_test)):  
            id = sess.run(index,feed_dict={xs:X_train,xst:X_test[i,:]})
            # 计算预测标签和正确标签用于比较
            Predict_label = np.argmax(Y_train[id])
            True_label = np.argmax(Y_test[i])

            print("Test Sample",i,"Prediction label:",Predict_label,
                  "True Class label:",True_label)

            # 计算精确度
            if Predict_label == True_label:
                Accuracy +=1
        print("Accuracy=",Accuracy/len(X_test))    

    return Accuracy    

if __name__ == '__main__':  
    X_train,Y_train,X_test,Y_test = load_mnist_data("MNIST_data",isbatch=1,train_nums=1000,test_nums=200)    
    Accuracy =  NN_Classifier(X_train,Y_train,X_test,Y_test,dims=784,dist_metric='L1') 
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值