import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data sets from the 'MNIST_data' directory with one-hot
# encoded labels (the helper presumably downloads them there if missing —
# confirm against the TensorFlow tutorial docs).
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
def network(num_images=10):
    """Fetch one MNIST training batch; show, save and label each image.

    For every image in the batch the picture is displayed with matplotlib,
    written to ``<index>.jpg`` in the current working directory, and the
    file name is printed together with the image's true digit label.

    Args:
        num_images: Number of images to fetch from ``mnist.train`` and save.
            Defaults to 10, the previously hard-coded batch size.
    """
    batch_xs, batch_ys = mnist.train.next_batch(num_images)
    # Each row of batch_xs is a flattened 28x28 image (784 = 28 * 28);
    # restore the 2-D layout with one trailing grey channel. A NumPy
    # reshape is sufficient, so no TensorFlow graph or Session is needed.
    images = np.reshape(batch_xs, (-1, 28, 28, 1))
    for i, (image, one_hot) in enumerate(zip(images, batch_ys)):
        # Copy the grey channel into R, G and B without scaling: matplotlib
        # clips float RGB data to [0, 1], so scaling up saturates the digit.
        rgb = np.repeat(image, 3, axis=2)
        # Labels were loaded with one_hot=True, so the true digit is the
        # index of the largest (the only non-zero) entry.
        label = int(np.argmax(one_hot))
        filename = str(i) + '.jpg'
        print(filename, label)
        fig, ax = plt.subplots()
        ax.imshow(rgb)
        # Save before plt.show(): saving after the window has been shown
        # and closed can produce a blank image.
        fig.savefig(filename)
        plt.show()
        # Release the figure so repeated calls do not accumulate figures.
        plt.close(fig)
if __name__ == "__main__":
    # Call network() only when this file is executed directly, not on import.
    network()
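# 读取存储mnist待识别的数字图片
# (Read the MNIST digit images to be recognized and save them to disk.)
# 最新推荐文章于 2019-06-20 20:59:29 发布
# (Latest recommended article published 2019-06-20 20:59:29.)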