tensorflow2.0使用GradientTape自定义训练鸢尾花分类器

本文介绍了如何使用Tensorflow 2.0的GradientTape进行自定义训练,以鸢尾花数据集为例,详细展示了模型定义、损失函数构建、训练过程及预测结果。
摘要由CSDN通过智能技术生成

  本文根据谷歌Tensorflow的官方事例进行介绍。官方事例在自定义训练:演示中查看。通过该实例可以更好的了解GradientTape的使用,尤其在自定义网络损失函数时尤其重要。
  在原事例中使用的是url下载谷歌的鸢尾花数据集,由于网络原因下载失败,于是选用sklearn中的鸢尾花数据集。具体实现代码如下:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
iris = load_iris()
features = iris["data"]  #shape: (150, 4)
labels = iris["target"]  #shape: (150,)
class_names = iris["target_names"]   #['setosa' 'versicolor' 'virginica']

  定义模型、损失函数、建立梯度带实现代码如下所示:

# 建立简单的全连接网络
model = tf.keras.Sequential([
  tf.keras.layers.Dense(10, activation=tf.nn.relu, input_shape=(4,)),  # 需要给出输入的形式
  tf.keras.layers.Dense(10, activation=tf.nn.relu),
  tf.keras.layers.Dense(3)
])

# 建立损失函数
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_lo
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值