代码来自这个博主,下面是我自己的一些理解。
链接: http://blog.nodetopo.com/2020/03/23/tensorflow2学习四/.
这是tf2.0的代码,tf1.*可以去看博主莫烦的代码。
使用tf.keras.layers.Conv操作要知道Conv1D与Conv2D的区别,简单的说Conv1D只进行纵向的移动提取特征,Conv2D则是横向以及纵向的移动来提取特征。
Flatten层用来将输入“压平”,即把多维的输入一维化。常用在卷积层到全链接层的过度,不影响batch的大小。
Dense是一个全连接层,它的激活函数默认为是linear线性函数,64表示样本参数输出大小。一般全连接层有两层或者两层以上,这是因为两层及以上可以很好地解决非线性问题,就是能根据多个不同因素去实现准确的分类。activation: 激活函数 (详见 activations)。 若不指定,则不使用激活函数 (即,「线性」激活: a(x) = x)。
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
#下载数据
(train_x, train_y), (test_x, test_y) = keras.datasets.mnist.load_data(path='mnist.npz')
#标准化
train_x, test_x = train_x / 255.0, test_x / 255.0
train_x, test_x = train_x[:, :, :, np.newaxis], test_x[:, :, :, np.newaxis]
#构建模型
model = keras.models.Sequential([
#valid:表示不够卷积核大小的块,则丢弃;same表示不够卷积核大小的块就补0,所以输出和输入形状相同.28是卷积核数目,(3,3)是filter即卷积核的大小。
keras.layers.Conv2D(28, (3, 3), strides=1, padding='same', activation='relu', input_shape=(28, 28, 1)),
keras.layers.MaxPooling2D((2, 2)),#(2,2)是池化层的窗口大小。
keras.layers.Conv2D(64, (3, 3), activation='relu'),
keras.layers.MaxPooling2D((2, 2)),
keras.layers.Conv2D(64, (3, 3), activation='relu'),
keras.layers.Flatten(),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dense(10)
])
#运行模型
model.compile(optimizer='SGD',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(train_x, train_y, epochs=10, validation_data=(test_x, test_y))
test_loss, test_acc = model.evaluate(test_x, test_y, verbose=2)
print('loss', test_loss, '\naccuracy: ',test_acc)
``
- model.compile()方法用于在配置训练方法时,告知训练时用的优化器、损失函数和准确率评测标准。
- model.fit() fit函数参数说明:
x:输入数据。如果模型只有一个输入,那么x的类型是numpy。
y:标签,numpy array。
epochs:整数,训练模型迭代次数。
validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。注意,validation_split的划分在shuffle之前,因此如果你的数据本身是有序的,需要先手工打乱再指定validation_split,否则可能会出现验证集样本不均匀。
validation_data:形式为(X,y)的tuple,是指定的验证集。此参数将覆盖validation_spilt
- model.evaluate()函数
输入数据和标签,输出损失和精确度
x:输入数据
y:输入标签
verbose:0不显示进度条,1为显示进度条