本文根据谷歌Tensorflow的官方事例进行介绍。官方事例在自定义训练:演示中查看。通过该实例可以更好的了解GradientTape的使用,尤其在自定义网络损失函数时尤其重要。
在原事例中使用的是url下载谷歌的鸢尾花数据集,由于网络原因下载失败,于是选用sklearn中的鸢尾花数据集。具体实现代码如下:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
iris = load_iris()
features = iris["data"] #shape: (150, 4)
labels = iris["target"] #shape: (150,)
class_names = iris["target_names"] #['setosa' 'versicolor' 'virginica']
定义模型、损失函数、建立梯度带实现代码如下所示:
# 建立简单的全连接网络
model = tf.keras.Sequential([
tf.keras.layers.Dense(10, activation=tf.nn.relu, input_shape=(4,)), # 需要给出输入的形式
tf.keras.layers.Dense(10, activation=tf.nn.relu),
tf.keras.layers.Dense(3)
])
# 建立损失函数
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_lo