import tensorflow as tf
import matplotlib.pyplot as plt
# 读取MNIST数据
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
# 获得数据集的个数
# train_nums=55000
# validation_nums=5000
# test_nums=10000
train_nums = mnist.train.num_examples
validation_nums = mnist.validation.num_examples
test_nums = mnist.test.num_examples
# 获得数据值
# 训练集数据大小train_data.shape=(55000, 784)
# 一副图像的大小train_data[0].shape=(784,)
train_data = mnist.train.images
val_data = mnist.validation.images
test_data = mnist.test.images
# 训练集标签数组大小train_labels.shape=(55000, 10)
# 一副图像的标签大小train_labels[1].shape=(10,)
# 一副图像的标签值train_labels[0]=[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
train_labels = mnist.train.labels #(55000,10)
val_labels = mnist.validation.labels #(5000,10)
test_labels = mnist.test.labels #(10000,10)
# 使用next_batch(batch_size)批量获取数据和标签
# 每次批量训练100幅图像
batch_size = 100
# batch_xs.shape =(100, 784),batch_ys.shape = (100,10)
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
# 显示图像
plt.figure()
for i in range(3):
im = train_data[i].reshape(28,28)
im = batch_xs[i].reshape(28,28)
plt.imshow(im,'gray')
plt.pause(0.0000001)
plt.show()
TensorFlow机器学习数据集
最新推荐文章于 2021-03-09 11:08:52 发布