mnist = input_data.read_data_sets("./mnist/", one_hot=True, reshape=False)  # reshape=False: images keep shape (?, 28, 28, 1)
# With reshaping enabled, the images are flattened to shape (?, 784) instead.
# References:
# https://www.cnblogs.com/qqw-1995/p/9805025.html
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Demo script: load MNIST from ./mnist and display one training sample.
mnist = input_data.read_data_sets('./mnist', one_hot=True)

# Print a quick summary of the training split. The trailing comments record
# the output noted alongside the original script.
train_set = mnist.train
print(train_set.images.shape)  # (55000, 784)
print(train_set.labels.shape)  # (55000, 10)
print(train_set.labels[1])     # [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]

# Turn the flat pixel array of sample 1 back into a 28x28 picture.
image = train_set.images[1].reshape(28, 28)

fig = plt.figure("图片展示")
plt.imshow(image, cmap='gray')
plt.axis('off')  # hide the axis ticks and labels
plt.show()
import math

import matplotlib.pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('./mnist', one_hot=True)


def drawdigit(position, image, title):
    """Draw a single MNIST image in one cell of the current figure.

    Args:
        position: Sequence unpacked into ``plt.subplot``; the caller in this
            file passes ``(rows, columns, 1-based index)``.
        image: 2-D pixel array to display.
        title: Caption shown above the subplot.
    """
    plt.subplot(*position)
    plt.imshow(image, cmap='gray_r')
    plt.axis('off')
    plt.title(title)


def batchDraw(batch_size):
    """Plot one training batch as a square grid of labelled digits.

    Args:
        batch_size: Number of samples to fetch from ``mnist.train`` and draw.
    """
    images, labels = mnist.train.next_batch(batch_size)

    # Smallest square grid that has room for every sample (rounded up).
    grid_side = math.ceil(batch_size ** 0.5)
    plt.figure(figsize=(grid_side, grid_side))

    # One subplot per sample, numbered from 1; grid cells beyond batch_size
    # are simply left empty.
    for index in range(batch_size):
        picture = images[index].reshape(28, 28)
        # The position of the largest entry in the label vector is the digit.
        digit = np.argmax(labels[index])
        cell = (grid_side, grid_side, index + 1)
        drawdigit(cell, picture, 'actual:%d' % digit)


if __name__ == '__main__':
    batchDraw(16)
    plt.show()