实验步骤
1、导入库
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
print("Tensorflow版本是:",tf.__version__)
2、数据获取
MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取
TensorFlow提供了数据集读取方法(1.x和2.0版本提供的方法不同)
mnist = tf.keras.datasets.mnist
(train_images,train_labels),(test_images,test_labels)=mnist.load_data()
MNIST数据集文件在读取时如果指定目录下不存在,则会自动去下载,需等待 一定时间;如果已经存在了,则直接读取
3、数据集划分
total_num = len(train_images)
valid_split = 0.2
train_num = int(total_num*(1-valid_split))
train_x = train_images[:train_num]
train_y = train_labels[:train_num]
valid_x = train_images[train_num:]
valid_y = train_labels[train_num:]
test_x = test_images
test_y = test_labels
valid_x.shape
4、数据塑形
train_x = train_x.reshape(-1,784)
valid_x = valid_x.reshape(-1,784