TensorFlow2.0教程-自定义训练实战(非tf.keras)

TensorFlow2.0教程-自定义训练实战(非tf.keras)

本教程我们将使用TensorFlow来实现鸢尾花分类。整个过程包括:构建模型、模型训练、模型预测。其中,虽然网络层仍是使用keras.layer的网络,但训练过程没有使用keras的方法,而是使用tensorflow2中eager模式的自动求导方法构造的。

愿文地址:https://doit-space.blog.csdn.net/article/details/95041068

最全Tensorflow 2.0 入门教程持续更新:https://blog.csdn.net/qq_31456593/article/details/88606284

完整tensorflow2.0教程代码请看 https://github.com/czy36mengfei/tensorflow2_tutorials_chinese (欢迎star)

本教程主要由tensorflow2.0官方教程的个人学习复现笔记整理而来,中文讲解,方便喜欢阅读中文教程的朋友,官方教程:https://www.tensorflow.org

导入相关库

导入TensorFlow和其他所需的Python模块。 默认情况下,TensorFlow2使用急切执行来程序,会立即返回结果。

from __future__ import absolute_import, division, print_function, unicode_literals
import os
import matplotlib.pyplot as plt
import tensorflow as tf
print('tf version:', tf.__version__)
print('eager execution:', tf.executing_eagerly())
tf version: 2.0.0-alpha0
eager execution: True

鸢尾花分类问题

想象一下,你是一名植物学家,正在寻找一种自动化的方法来对你找到的每种鸢尾花进行分类。 机器学习提供了许多算法来对花进行统计分类。 例如,复杂的机器学习程序可以基于照片对花进行分类。而这里,我们将根据萼片和花瓣的长度和宽度测量来对鸢尾花进行分类。

鸢尾花有300多种类别,但我们的这里主要对以下三种进行分类:

  • Iris setosa
  • Iris virginica
  • Iris versicolor
    # [外链图片转存失败(img-YiS1cGTo-1562512656938)(https://www.tensorflow.org/images/iris_three_species.jpg)]

幸运的是,有人已经用萼片和花瓣测量创建了120个鸢尾花的数据集。 这是一个流行的初学者机器学习分类问题的经典数据集。

下载数据集
使用tf.keras.utils.get_file函数下载训练数据集文件。 这将返回下载文件的文件路径。

train_dataset_url = "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv"
train_dataset_fp = tf.keras.utils.get_file(fname=os.path.basename(train_dataset_url),
                                          origin=train_dataset_url)
print('下载数据至:', train_dataset_fp)
下载数据至: /root/.keras/datasets/iris_training.csv

检查数据

此数据集iris_training.csv是一个纯文本文件,用于存储格式为逗号分隔值(CSV)的表格数据。 使用head -n5命令在前五个条目中取一个峰值:

!head -n5 {
   train_dataset_fp}
120,4,setosa,versicolor,virginica
6.4,2.8,5.6,2.2,2
5.0,2.3,3.3,1.0,1
4.9,2.5,4.5,1.7,2
4.9,3.1,1.5,0.1,0

从数据集的此视图中,请注意以下内容:

第一行是包含有关数据集的信息的标题:
总共有120个例子。 每个示例都有四个特征和三个可能的标签名称之一。
后续行是数据记录,每行一个示例,其中:
前四个字段是特征:这些是示例的特征。 这里,字段包含代表花卉测量值的浮点数。
最后一列是标签:这是我们想要预测的值。 对于此数据集,它是与花名称对应的整数值0,1或2。

column_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']
# 获取特征和标签名
feature_name = column_names[:-1]
label_name = column_names[-1]

每个标签都与字符串名称相关联(例如,“setosa”),但机器学习通常依赖于数值。使用标签数字来映射类别,例如:

  • 0:Iris setosa:
  • 1:Iris versicolor
  • 2:Iris virginica
class_names = ['Iris setosa', 'Iris versicolor', 'Iris virginica']

创建一个 tf.data.Dataset

TensorFlow的数据集API处理许多将数据加载到模型中的常见情况。这是一个高级API,用于读取数据并将其转换为用于训练的数据类型。

由于数据集是CSV格式的文本文件,因此需要使用make_csv_dataset函数将数据解析为合适的格式。由于此函数为训练模型生成数据,因此默认行为是对数据(shuffle=True, shuffle_buffer_size=10000)进行混洗,并永远重复数据集(num_epochs=None)。同时还需要设置batch_size参数。

batch_size=32
train_dataset = tf.data.experimental.make_csv_dataset(
    train_dataset_fp,
    batch_size,
    column_names=column_names,
    label_name=label_name,
    num_epochs=1
)

该make_csv_dataset函数返回tf.data.Dataset的(features, label)对,其中features是一个字典:{‘feature_name’: value}

而这些Dataset对象是可迭代的。

features, labels = next(iter(train_dataset))
print(features)
OrderedDict([('sepal_length', <tf.Tensor: id=64, shape=(32,), dtype=float32, numpy=
array([7.6, 6.9, 7.2, 5. , 6.7, 4.8, 5.4, 5.1, 7.7, 6. , 6.3, 7.4, 5.2,
       7.2, 6.7, 6.1, 5. , 4.9, 6.2, 4.5, 6.6, 6. , 5.5, 6.3, 4.8, 6.7,
       6.1, 5.6, 7.3, 6.9, 5
  • 5
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值