数据处理
import tensorflow as tf
import keras
from keras import layers
下载数据集(用vpn)
import keras.datasets.mnist as mnist
(train_image, train_label), (test_image, test_label) = mnist.load_data()
查看训练集
train_image.shape
训练集是60000张28*28像素的图片组成
图像的数据的shape
hight width channel(黑白图像是1,彩色图像是3)
conv2d 要求数据是一个高 宽 channel 形状的图像
conv2d:图片输入形状:batch(有多少张图片) ,高,宽, channel
为了给conv2d输入图像,我们需要将图片扩宽一个用来表示channel的维度
dense:图片输入形状:batch(有多少张图片) ,data
train_image = np.expand_dims(train_image, axis=-1) ##axis=-1表示在最后一个维度上扩增
此时再看训练集图像形状
在这里插入代码片
此时变为一个四维数据形状
测试集也扩宽维度
test_image = np.expand_dims(test_image, axis=-1)
初始化模型
model = keras.Sequential()
添加层,构建网络
##卷积层
model.add(layers.Conv2D(64, (3, 3), activation='relu', input_shape=(28,