1、绪论
在TensorFlow中,fit() 方法是一个用于训练模型的便捷函数,它封装了训练循环(training loop)的许多常见步骤,如前向传播(forward pass)、计算损失(loss)、反向传播(backward pass)以及模型权重的更新。然而,有时候你可能想要对训练过程进行更细粒度的控制,或者添加一些自定义的步骤。
当进行监督学习时,可以使用fit()
方法,并且一切都会顺利进行。
但是,当需要控制每一个小细节时,就可以完全从头开始编写自己的训练循环。
但如果需要一个自定义的训练算法,但又想从fit()
的便捷功能中受益,比如回调(callbacks)、内置的分布支持(built-in distribution support)或步骤融合(step fusing)时,又该如何呢?
Keras的一个核心原则是复杂性的逐步展现。总是能够逐步深入到更低级别的工作流程中。如果高级功能不完全符合你的测试用例,也不会突然陷入困境。我们可以在保留相应级别的高级便利性的同时,对细节获得更多的控制权。
当需要自定义fit()
的行为时,我们应该重写Model类的训练步骤函数。这是fit()
在处理每一批数据时调用的函数。然后你就可以像平常一样调用fit()
——而它将会运行你自己的学习算法。
请注意,这种模式并不会阻止你使用函数式API构建模型。无论你是构建Sequential模型、函数式API模型还是子类化模型,都可以采用这种方法。
2、准备工作
#2.1 基础设置
开始操作前请按照如下进行基础设置
import os
# This guide can only be run with the TF backend.
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import keras
from keras import layers
import numpy as np
2.2 操作示例
以下是一个使用TensorFlow来自定义fit()的示例:
首先需要创建一个新的类,该类继承自keras.Model
。
然后重写train_step(self, data)
这个方法。
之后我们返回数据字典,该字典将度量指标名称(包括损失)映射到它们的当前值。
输入参数data
是传递给fit
方法的训练数据:
- 如果你通过调用
fit(x, y, ...)
传递NumPy数组,那么data
将是元组(x, y)
- 如果你通过调用
fit(dataset, ...)
传递一个tf.data.Dataset
,那么data
将是dataset
在每个批次中产生的数据。
在train_step()
方法的主体中,我们实现了一个常规的训练更新过程,类似于你已经熟悉的过程。重要的是,我们通过self.compute_loss()
计算损失,该方法封装了在compile()
方法中传递的损失函数。
类似地,我们调用metric.update_state(y, y_pred)
来更新在compile()
方法中传递的度量指标的状态,并在最后通过self.metrics
查询结果来获取它们的当前值。
class CustomModel(keras.Model):
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
x, y = data
with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute the loss value
# (the loss function is configured in `compile()`)
loss = self.compute_loss(y=y, y_pred=y_pred)
# Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply(gradients, trainable_vars)
# Update metrics (includes the metric that tracks the loss)
for metric in self.metrics:
if metric.name == "loss":
metric.update_state(loss)
else:
metric.update_state(y, y_pred)
# Return a dict mapping metric names to current value
return {
m.name: m.result() for m in self.metrics}
运行代码,看看输出结果
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
Epoch 1/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.5089 - loss: 0.3778
Epoch 2/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 318us/step - mae: 0.3986 - loss: 0.2466
Epoch 3/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 372us/step - mae: 0.3848 - loss: 0.2319
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1699222602.443035 1 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
<keras.src.callbacks.history.History at 0x2a5599f00>
3、底层操作方法
在操作过程中,可以在compile()
方法中省略损失函数的传递,而是在train_step
中手动完成所有操作。对于度量指标(metrics)也是如此。
以下是一个更加底层操作的示例,它仅使用compile()
方法来配置优化器:
首先,我们在__init__()
方法中创建度量指标实例来跟踪损失和平均绝对误差(MAE)分数。
然后,我们实现一个自定义的train_step()
,更新这些度量指标的状态(通过调用它们的update_state()
方法),接着查询它们(通过result()
方法)来返回当前平均值,以便进度条显示并传递给任何回调函数。
请注意,在每个epoch之间,我们需要调用度量指标的reset_states()
方法!否则,调用result()
会返回从训练开始以来的平均值,而我们通常处理的是每个epoch的平均值。幸运的是,框架可以为我们完成这一操作:只需在模型的metrics
属性中列出你希望重置的任何度量指标对象。在每个fit()
epoch的开始或调用evaluate()
时,模型会自动调用这些对象上的reset_states()
方法。
class CustomModel(keras.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args,