from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from mnist import dataset
from mnist.utils.flags import core as flags_core
from mnist.utils.logs import hooks_helper
from mnist.utils.misc import distribution_utils
from mnist.utils.misc import model_helpers
LEARNING_RATE = 1e-4
def create_model(data_format):
"""Model to recognize digits in the MNIST dataset.
Network structure is equivalent to:
https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
and
https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py
But uses the tf.keras API.
Args:
data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
typically faster on GPUs while 'channels_last' is typically faster on
CPUs. See
https://www.tensorflow.org/performance/performance_guide#data_formats
Returns:
A tf.keras.Model.
"""
#定义数据格式
if data_format == 'channels_first':
input_shape = [1, 28, 28]
else:
assert data_format == 'channels_last'
input_shape = [28, 28, 1]
# 定义模块开始
l = tf.keras.layers
#定义pool层
max_pool = l.MaxPooling2D((2, 2), (2, 2), padding=
tensorflow入门笔记-03 mnisy.py注解
最新推荐文章于 2021-04-01 01:16:53 发布
本文是TensorFlow入门系列的第三篇,我们将通过MNIST数据集进行实战,了解如何加载数据、构建神经网络模型并训练。内容包括数据预处理、搭建CNN模型、损失函数和优化器的选择,以及模型的训练与评估。
摘要由CSDN通过智能技术生成