CNN分类mnist

代码来自这个博主,下面是我自己的一些理解。
链接: http://blog.nodetopo.com/2020/03/23/tensorflow2学习四/.

这是tf2.0的代码,tf1.*可以去看博主莫烦的代码。

使用tf.keras.layers.Conv操作要知道Conv1D与Conv2D的区别,简单的说Conv1D只进行纵向的移动提取特征,Conv2D则是横向以及纵向的移动来提取特征。
Flatten层用来将输入“压平”,即把多维的输入一维化。常用在卷积层到全链接层的过度,不影响batch的大小。
Dense是一个全连接层,它的激活函数默认为是linear线性函数,64表示样本参数输出大小。一般全连接层有两层或者两层以上,这是因为两层及以上可以很好地解决非线性问题,就是能根据多个不同因素去实现准确的分类。activation: 激活函数 (详见 activations)。 若不指定,则不使用激活函数 (即,「线性」激活: a(x) = x)。

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt

#下载数据
(train_x, train_y), (test_x, test_y) = keras.datasets.mnist.load_data(path='mnist.npz')

#标准化
train_x, test_x = train_x / 255.0, test_x / 255.0
train_x, test_x = train_x[:, :, :, np.newaxis], test_x[:, :, :, np.newaxis]

#构建模型
model = keras.models.Sequential([
    #valid:表示不够卷积核大小的块,则丢弃;same表示不够卷积核大小的块就补0,所以输出和输入形状相同.28是卷积核数目,(3,3)是filter即卷积核的大小。
    keras.layers.Conv2D(28, (3, 3), strides=1, padding='same', activation='relu', input_shape=(28, 28, 1)),
    keras.layers.MaxPooling2D((2, 2)),#(2,2)是池化层的窗口大小。
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Flatten(),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(10)
])

#运行模型
model.compile(optimizer='SGD',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(train_x, train_y, epochs=10, validation_data=(test_x, test_y))
test_loss, test_acc = model.evaluate(test_x,  test_y, verbose=2)
print('loss', test_loss, '\naccuracy: ',test_acc)
``

 - model.compile()方法用于在配置训练方法时,告知训练时用的优化器、损失函数和准确率评测标准。

 - model.fit() fit函数参数说明:
x:输入数据。如果模型只有一个输入,那么x的类型是numpy。
y:标签,numpy array。
epochs:整数,训练模型迭代次数。

validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。注意,validation_split的划分在shuffle之前,因此如果你的数据本身是有序的,需要先手工打乱再指定validation_split,否则可能会出现验证集样本不均匀。

validation_data:形式为(X,y)的tuple,是指定的验证集。此参数将覆盖validation_spilt

 - model.evaluate()函数
 输入数据和标签,输出损失和精确度
x:输入数据
y:输入标签
verbose:0不显示进度条,1为显示进度条
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值