用TensorFlow基于最近邻域法实现图像识别

1、导入编程库

import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as  plt
from PIL import Image
from tensorflow.examples.tutorials.mnist import input_data

2、创建会话,加载数据集

sess = tf.Session()
mnist = input_data.read_data_sets("E:\Python Project\mnist\MNIST_data", one_hot=True)

3、分割数据集

train_size = 100
test_size = 102
rand_train_indices = np.random.choice(len(mnist.train.images),train_size,replace =False)

rand_test_indices = np.random.choice(len(mnist.test.images),test_size,replace = False)
x_vals_train = mnist.train.images[rand_train_indices]
x_vals_test = mnist.test.images[rand_test_indices]

y_vals_train = mnist.train.labels[rand_train_indices]
y_vals_test = mnist.train.labels[rand_test_indices]

4、声明K值,批量大小,占位符等

k = 4
batch_size = 6
x_data_train = tf.placeholder(shape=[None, 784], dtype = tf.float32)
x_data_test = tf.placeholder(shape=[None, 784], dtype = tf.float32)
y_target_train = tf.placeholder(shape=[None,10], dtype = tf.float32)
y_target_test = tf.placeholder(shape=[None,10], dtype = tf.float32)

5、声明距离度量函数

distance = tf.reduce_sum(tf.abs(tf.subtract(x_data_train, tf.expand_dims(x_data_test,1))), reduction_indices=2)

6、找到最接近的top k图片和预测模型

top_x_xvals, top_k_indices = tf.nn.top_k(tf.negative(distance),k=k)
prediction_indices = tf.gather(y_target_train, top_k_indices)
count_of_predictions = tf.reduce_sum(prediction_indices, axis=1)
prediction = tf.argmax(count_of_predictions, axis=1)

7、遍历迭代,计算预测值,并将结果存储

test_output = []
actual_vals = []

for i in range(num_loops):
    min_index = i * batch_size
    max_index = min((i+1)*batch_size,len(x_vals_train))
    x_batch = x_vals_test[min_index:max_index]
    y_batch = y_vals_test[min_index:max_index]
    predictions = sess.run(prediction, feed_dict={x_data_train:x_vals_train,x_data_test:x_batch,y_target_train:y_vals_train,y_target_test:y_batch})
    test_output.extend(predictions)
    actual_vals.extend(np.argmax(y_batch, axis=1))
    

8、计算准确度

accuarcy = sum([1./test_size for i in range(test_size) if test_output[i] == actual_vals[i]])

9、绘制最后批次的计算结果

actuals = np.argmax(y_batch, axis=1)

Nrows = 2
Ncols = 3
for i in range(len(actuals)):
    plt.subplot(Nrows, Ncols, i+1)
    plt.imshow(np.reshape(x_batch[i], [28,28]), cmap='Greys_r')
    plt.title('Actual: ' + str(actuals[i]) + ' Pred: ' + str(predictions[i]),
                               fontsize=10)
    frame = plt.gca()
    frame.axes.get_xaxis().set_visible(False)
    frame.axes.get_yaxis().set_visible(False)
    
plt.show()

9、运行结果
在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值