判别器loss为0_TensorFlow 2.0 模型:多层感知机

本文介绍了如何在 TensorFlow 2.0 中构建和训练多层感知机(MLP),使用 MNIST 数据集,涵盖了数据预处理、模型构建、损失函数计算、优化器应用和模型评估的全过程。
摘要由CSDN通过智能技术生成

文 /  李锡涵,Google Developers Expert

本文节选自《简单粗暴 TensorFlow 2.0》

ddc9009ba4f3bf76bde31590b3eaeace.gif

在 上一篇文章 里,我们简要介绍了 TensorFlow 2.0 中建立模型类的方法。本文即以多层感知机 (Multilayer Perceptron, MLP),或者说 “多层全连接神经网络” 为例,给出一个具体示例,详细介绍 TensorFlow 2.0 的模型构建、训练、评估的全流程。在这一部分,我们依次进行以下步骤:

  • 使用 tf.keras.datasets 获得数据集并预处理

  • 使用 tf.keras.Modeltf.keras.layers 构建模型

  • 构建模型训练流程,使用 tf.keras.losses 计算损失函数,并使用 tf.keras.optimizer 优化模型

  • 构建模型评估流程,使用 tf.keras.metrics 计算评估指标

基础知识和原理

UFLDL 教程 | 神经网络 一节

斯坦福课程 CS231n: Convolutional Neural Networks for Visual Recognition 中的 “Neural Networks Part 1 ~ 3” 部分。

注:神经网络 链接

http://ufldl.stanford.edu/tutorial/supervised/MultiLayerNeuralNetworks/

CS231n 链接

http://cs231n.github.io/

这里,我们使用多层感知机完成 MNIST 手写体数字图片数据集 [LeCun1998] 的分类任务。 022f6794bde943fa2ad80e80f7d4e35a.png

MNIST 手写数字图片示例

数据获取及预处理:tf.keras.datasets

先进行预备工作,实现一个简单的 MNISTLoader 类来读取 MNIST 数据集数据。这里使用了 tf.keras.datasets 快速载入 MNIST 数据集。
 1class MNISTLoader():
2    def __init__(self):
3        mnist = tf.keras.datasets.mnist
4        (self.train_data, self.train_label), (self.test_data, self.test_label) = mnist.load_data()
5        # MNIST中的图像默认为uint8(0-255的数字)。以下代码将其归一化到0-1之间的浮点数,并在最后增加一维作为颜色通道
6        self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 28, 28, 1]
7        self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)        # [10000, 28, 28, 1]
8        self.train_label = self.train_label.astype(np.int32)    # [60000]
9        self.test_label = self.test_label.astype(np.int32)      # [10000]
10        self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]
11
12    def get_batch(self, batch_size):
13        # 从数据集中随机取出batch_size个元素并返回
14        index = np.random.randint(0, np.shape(self.train_data)[0], batch_size)
15        return self.train_data[index, :], self.train_label[index]

提示

mnist = tf.keras.datasets.mnist 将从网络上自动下载 MNIST 数据集并加载。如果运行时出现网络连接错误,可以从 ht

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值