Keras Deep Learning Framework, Lecture 8: Writing Custom Training and Evaluation Loops

1. Introduction

Keras provides default training and evaluation loops: the fit() and evaluate() methods. Their usage is covered in detail in the chapter "Training and evaluation with the built-in methods".

If you want to customize your model's learning algorithm while still leveraging the convenience of fit() (for instance, to train a GAN using fit()), you can subclass the Model class and implement your own train_step() method, which is called repeatedly while fit() runs.

However, if you want very low-level control over training and evaluation, you should write your own training and evaluation loops from scratch.

When writing training and evaluation loops from scratch, you have to handle batch iteration over the data, the forward pass, loss computation, backpropagation (gradient computation), and parameter updates yourself. Writing these loops by hand gives you much greater flexibility, at the cost of more code and more complex logic.
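As a framework-free sketch of the steps just listed, here is what such a loop looks like for a hypothetical one-parameter linear model (the analytic gradient stands in for a real backward pass; nothing here is Keras-specific):

```python
# Minimal hand-written training loop: batch iteration, forward pass,
# loss computation, gradient computation, parameter update.
# Toy 1-D linear model y = w * x with an analytic gradient.

data = [(float(x), 2.0 * x) for x in range(10)]  # learn y = 2x
w = 0.0
lr = 0.01
for epoch in range(50):
    for x, y in data:                   # batch iteration (batch size 1)
        y_pred = w * x                  # forward pass
        loss = (y_pred - y) ** 2        # loss computation
        grad = 2 * (y_pred - y) * x     # "backpropagation" (analytic gradient)
        w -= lr * grad                  # parameter update
print(round(w, 3))  # converges to 2.0
```

The sections below replace each of these hand-rolled pieces with the corresponding backend machinery.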

2. Setup

For the JAX backend, use the following setup:

import os

# This guide can only be run with the jax backend.
os.environ["KERAS_BACKEND"] = "jax"

import jax

# We import TF so we can use tf.data.
import tensorflow as tf
import keras
import numpy as np

For the TensorFlow backend, use the following setup:

import time
import os

# This guide can only be run with the TensorFlow backend.
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
import numpy as np

For the PyTorch backend, use the following setup:

import os

# This guide can only be run with the torch backend.
os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras
import numpy as np

3. Worked examples

To write a custom training loop, we need the following ingredients:

  1. A model to train.
  2. An optimizer. You could use an optimizer from keras.optimizers, or one from the optax package.
  3. A loss function.
  4. A dataset.

First, let's get the model and the MNIST dataset:
def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = get_model()

# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

Next, let's define the loss function and the optimizer. In this example we'll use a Keras optimizer.

# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)

3.1 Writing training and evaluation loops from scratch in JAX

In JAX, gradients are computed via metaprogramming: you call the jax.grad (or jax.value_and_grad) function on another function in order to create a gradient-computing function for that first function.

So the first thing we need is a function that returns the loss value. That's the function we'll use to generate the gradient function. It might look like this:

def compute_loss(x, y):
    ...
    return loss

Once you have defined the loss function, you can compute gradients via metaprogramming, as follows:

grad_fn = jax.grad(compute_loss)
grads = grad_fn(x, y)

Typically, you don't just want the gradient values; you also want the loss value. You can get both by using jax.value_and_grad instead:

grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(x, y)

3.1.1 Accommodating the statelessness of JAX computation

In JAX, everything must be a stateless function, so our loss computation function must be stateless as well. That means that all Keras variables (e.g. weight tensors) must be passed as function inputs, and any variable that has been updated during the forward pass must be returned as a function output. The function can have no side effects.
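The stateless pattern can be illustrated in plain Python, independently of JAX (a hypothetical counter, not Keras code): rather than mutating state in place, a stateless function takes the current state as input and returns the updated state as output:

```python
# Stateful version: mutates its argument (has side effects, not JAX-friendly).
def stateful_step(counter):
    counter["calls"] += 1
    return counter["calls"]

# Stateless version: no side effects; the caller threads the state through.
def stateless_step(state):
    new_state = {"calls": state["calls"] + 1}
    return new_state["calls"], new_state

state = {"calls": 0}
result, state = stateless_step(state)
result, state = stateless_step(state)
print(result, state)  # 2 {'calls': 2}
```

This "thread the state through" shape is exactly what the compute_loss_and_updates and train_step functions below follow.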

During the forward pass, the non-trainable variables of a Keras model might get updated. These variables could be, for instance, RNG seed state variables or BatchNormalization statistics. We're going to need to return those. So our training loop will need something like this:

def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    ...
    return loss, non_trainable_variables

Once you have defined the function above, you can get the gradient function by specifying has_aux in value_and_grad: it tells JAX that the loss computation function returns more outputs than just the loss. Note that the loss should always be the first output.

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
(loss, non_trainable_variables), grads = grad_fn(
    trainable_variables, non_trainable_variables, x, y
)

With the basics established, let's implement this compute_loss_and_updates function. Keras models have a stateless_call method which will come in handy here. It works just like model.__call__, but it requires you to explicitly pass the value of all the variables in the model, and it returns not just the __call__ outputs but also the (potentially updated) non-trainable variables.

def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    return loss, non_trainable_variables

We can then get the gradient function like this:

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)

3.1.2 The training step function

Next, let's implement an end-to-end training step function: a function that runs the forward pass, computes the loss, computes the gradients, and uses the optimizer to update the trainable variables. This function also needs to be stateless, so it will take as input a state tuple that includes every state element we're going to use:

  • trainable_variables and non_trainable_variables: the model's variables.
  • optimizer_variables: the optimizer's state variables, such as momentum accumulators.

To update the trainable variables, we use the optimizer's stateless method stateless_apply. It's equivalent to optimizer.apply(), but it always requires passing optimizer_variables and trainable_variables, and it returns both the updated trainable variables and the updated optimizer_variables.

def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

3.1.3 Speeding it up with jax.jit

By default, JAX operations run eagerly, just like in TensorFlow eager mode and PyTorch eager mode. And just like those, eager execution is quite slow: eager mode is better suited as a debugging environment than as a way to do actual work. So let's speed up our train_step by compiling it.

When you have a stateless JAX function, you can compile it to XLA via the @jax.jit decorator. It will get traced on its first execution, and in subsequent executions the traced graph will be executed instead (this is just like @tf.function(jit_compile=True)).
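As a rough plain-Python analogy of trace-and-cache (hypothetical; real XLA compilation is far more involved), the idea is to pay a one-time cost per input signature and reuse the cached result on subsequent calls:

```python
# A toy "jit": trace (here: just record) once per input signature,
# then reuse the cached function on later calls with the same signature.
trace_count = {"n": 0}

def jit_like(fn):
    cache = {}
    def wrapped(x):
        key = type(x)               # stand-in for a shape/dtype signature
        if key not in cache:
            trace_count["n"] += 1   # "tracing" happens only on the first call
            cache[key] = fn         # a real JIT would store a compiled graph
        return cache[key](x)
    return wrapped

@jit_like
def square(x):
    return x * x

square(2.0), square(3.0), square(4.0)
print(trace_count["n"])  # 1 -- traced only once for float inputs
```

This is also why a jitted function re-traces when you call it with new shapes or dtypes: those form a new signature.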

@jax.jit
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

After all this preparation, we're ready to train the model. The training loop itself is trivial: we just repeatedly call loss, state = train_step(state, data).

Note:

  • We convert the TF tensors yielded by the tf.data.Dataset to NumPy arrays before passing them to our JAX function.
  • All variables must be built beforehand: the model must be built and the optimizer must be built. Since we're using a Functional API model, it's already built upon instantiation; but if it were a subclassed model, you'd need to call it on a batch of data in order to build it.
# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 156.4785
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.5526
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 1.8922
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.2381
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.4812
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 2.3339
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.5615
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.6471
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 1.6272
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.9416
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.8152
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.8838
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.1278
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 1.9234
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.3413
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.2429
Seen so far: 48032 samples

3.1.4 Low-level handling of metrics

To keep track of what's happening at a low level, you can add metrics monitoring to this basic training loop.

You can readily reuse built-in Keras metrics (or custom ones you have written) in such a hand-written training loop. Here's the flow:

  1. Instantiate the metric at the start of the loop.
  2. Include metric_variables in the train_step arguments and the compute_loss_and_updates arguments.
  3. Call metric.stateless_update_state() in the compute_loss_and_updates function. It's equivalent to update_state(), only stateless.
  4. When you need to display the current value of the metric, outside train_step (in the eager scope), attach the new metric variable values to the metric object and call metric.result().
  5. Call metric.reset_state() when you need to clear the state of the metric (typically at the end of an epoch).
# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()


def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)


grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)


@jax.jit
def train_step(state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    ) = state
    x, y = data
    (loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
        trainable_variables, non_trainable_variables, metric_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    )

The evaluation step function looks like this:

@jax.jit
def eval_step(state, data):
    trainable_variables, non_trainable_variables, metric_variables = state
    x, y = data
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = val_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        metric_variables,
    )

And here are the training and evaluation loops:

# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
metric_variables = train_acc_metric.variables
state = (
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
)

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, _, metric_variables = state
        for variable, value in zip(train_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Training accuracy: {train_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")

# Reuse the trained state, swapping in the validation metric's variables.
(
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    _,
) = state
metric_variables = val_acc_metric.variables
state = trainable_variables, non_trainable_variables, metric_variables

# Eval loop
for step, data in enumerate(val_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = eval_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Validation loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, metric_variables = state
        for variable, value in zip(val_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Validation accuracy: {val_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")

3.1.5 Low-level handling of losses tracked by the model

Layers and models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values is available via the model.losses property at the end of the forward pass.

If you want to be using these loss components, you should sum them and add them to the main loss in your training step.
Consider this layer, which creates an activity regularization loss:

class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * jax.numpy.sum(inputs))
        return inputs

You can then use it when building a model, like this:

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

Here's what our compute_loss_and_updates function should look like now:

  • Pass return_losses=True to model.stateless_call().
  • Sum the resulting losses and add them to the main loss.
def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables, losses = model.stateless_call(
        trainable_variables, non_trainable_variables, x, return_losses=True
    )
    loss = loss_fn(y, y_pred)
    if losses:
        loss += jax.numpy.sum(losses)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)

3.2 Writing training and evaluation loops from scratch in TensorFlow

GradientTape 范围内调用模型允许我们获取关于损失值的可训练层权重的梯度。使用优化器实例,可以使用这些梯度来更新这些变量(可以通过 model.trainable_weights 检索)。

Here are the steps of our training loop:

  • We open a for loop that iterates over epochs
  • For each epoch, we open a for loop that iterates over the dataset, in batches
  • For each batch, we open a GradientTape() scope
  • Inside this scope, we call the model (forward pass) and compute the loss
  • Outside the scope, we retrieve the gradients of the weights of the model with regard to the loss
  • Finally, we use the optimizer to update the weights of the model based on the gradients
epochs = 3
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:
            # Run the forward pass of the layer.
            # The operations that the layer applies
            # to its inputs are going to be recorded
            # on the GradientTape.
            logits = model(x_batch_train, training=True)  # Logits for this minibatch

            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)

        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply(grads, model.trainable_weights)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

3.2.1 Low-level handling of metrics

To keep track of what's happening at a low level, you can add metrics monitoring to this basic training loop.

You can readily reuse built-in metrics (or custom ones you have written) in such a custom training loop. Here's the flow:

  • Instantiate the metric at the start of the loop
  • Call metric.update_state() after each batch
  • Call metric.result() when you need to display the current value of the metric
  • Call metric.reset_state() when you need to clear the state of the metric (typically at the end of an epoch)

Let's use this knowledge to compute SparseCategoricalAccuracy on training and validation data at the end of each epoch:

# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

Here are the training and evaluation loops:

epochs = 2
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply(grads, model.trainable_weights)

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
    print(f"Time taken: {time.time() - start_time:.2f}s")

3.2.2 Speeding up your training step with tf.function

The default runtime in TensorFlow is eager execution. As such, our training loop above executes eagerly.

This is great for debugging, but graph compilation has a definite performance advantage. Describing your computation as a static graph enables the framework to apply global performance optimizations. This is impossible when the framework is constrained to greedily execute one operation after another, with no knowledge of what comes next.

You can compile into a static graph any function that takes tensors as input. Just add a @tf.function decorator on it, like this:

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply(grads, model.trainable_weights)
    train_acc_metric.update_state(y, logits)
    return loss_value

Let's do the same with the evaluation step:

@tf.function
def test_step(x, y):
    val_logits = model(x, training=False)
    val_acc_metric.update_state(y, val_logits)

Now, let's re-run our training loop with this compiled training step:

epochs = 2
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
    print(f"Time taken: {time.time() - start_time:.2f}s")

3.2.3 Low-level handling of losses tracked by the model

Layers and models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values is available via the model.losses property at the end of the forward pass.

If you want to be using these loss components, you should sum them and add them to the main loss in your training step.

Consider this layer, which creates an activity regularization loss:

class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * tf.reduce_sum(inputs))
        return inputs

When building a model, you can use the regularization loss like this:

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

Here's the training step, now taking the regularization loss into account:

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
        # Add any extra losses created during the forward pass.
        loss_value += sum(model.losses)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply(grads, model.trainable_weights)
    train_acc_metric.update_state(y, logits)
    return loss_value

3.2.4 An end-to-end example: a GAN training loop from scratch in TensorFlow

This section demonstrates the whole workflow with a GAN. Generative Adversarial Networks (GANs) can generate new images that look almost real, by learning the latent distribution of a training dataset of images (the "latent space" of the images).

A GAN is made of two parts: a "generator" model that maps points in the latent space to points in image space, and a "discriminator" model, a classifier that can tell the difference between real images (from the training dataset) and fake images (the output of the generator network).

A GAN training loop looks roughly like this:

  1. Train the discriminator.
     - Sample a batch of random points in the latent space.
     - Turn the points into fake images via the "generator" model.
     - Get a batch of real images and combine them with the generated images.
     - Train the "discriminator" model to classify generated vs. real images.

  2. Train the generator.
     - Sample random points in the latent space.
     - Turn the points into fake images via the "generator" network.
     - Get a batch of real images and combine them with the generated images.
     - Train the "generator" model to "fool" the discriminator and classify the fake images as real.

For a much more detailed overview of how GANs work, see Deep Learning with Python.

Now let's implement this training loop. First, create the discriminator meant to classify fake vs. real digits:

discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.GlobalMaxPooling2D(),
        keras.layers.Dense(1),
    ],
    name="discriminator",
)
discriminator.summary()

Then let's create a generator network, which turns latent vectors into outputs of shape (28, 28, 1) (representing MNIST digits):

latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        keras.layers.Dense(7 * 7 * 128),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Reshape((7, 7, 128)),
        keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

Here's the training loop code:

# Instantiate one optimizer for the discriminator and another for the generator.
d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)

# Instantiate a loss function.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)


@tf.function
def train_step(real_images):
    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Decode them to fake images
    generated_images = generator(random_latent_vectors)
    # Combine them with real images
    combined_images = tf.concat([generated_images, real_images], axis=0)

    # Assemble labels discriminating real from fake images
    labels = tf.concat(
        [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0
    )
    # Add random noise to the labels - important trick!
    labels += 0.05 * tf.random.uniform(labels.shape)

    # Train the discriminator
    with tf.GradientTape() as tape:
        predictions = discriminator(combined_images)
        d_loss = loss_fn(labels, predictions)
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    d_optimizer.apply(grads, discriminator.trainable_weights)

    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Assemble labels that say "all real images"
    misleading_labels = tf.zeros((batch_size, 1))

    # Train the generator (note that we should *not* update the weights
    # of the discriminator)!
    with tf.GradientTape() as tape:
        predictions = discriminator(generator(random_latent_vectors))
        g_loss = loss_fn(misleading_labels, predictions)
    grads = tape.gradient(g_loss, generator.trainable_weights)
    g_optimizer.apply(grads, generator.trainable_weights)
    return d_loss, g_loss, generated_images

Let's train the GAN by repeatedly calling train_step on batches of images.

# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

epochs = 1  # In practice you need at least 20 epochs to generate nice digits.
save_dir = "./"

for epoch in range(epochs):
    print(f"\nStart epoch {epoch}")

    for step, real_images in enumerate(dataset):
        # Train the discriminator & generator on one batch of real images.
        d_loss, g_loss, generated_images = train_step(real_images)

        # Logging.
        if step % 100 == 0:
            # Print metrics
            print(f"discriminator loss at step {step}: {d_loss:.2f}")
            print(f"adversarial loss at step {step}: {g_loss:.2f}")

            # Save one generated image
            img = keras.utils.array_to_img(generated_images[0] * 255.0, scale=False)
            img.save(os.path.join(save_dir, f"generated_img_{step}.png"))

        # To limit execution time we stop after 10 steps.
        # Remove the lines below to actually train the model!
        if step > 10:
            break

3.3 Writing a custom training loop in PyTorch

3.3.1 Defining the optimizer and the loss function
# Instantiate a torch optimizer
model = get_model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Instantiate a torch loss function
loss_fn = torch.nn.CrossEntropyLoss()

3.3.2 Training with mini-batch gradient descent and a custom loop

The example below trains our model using mini-batch gradient descent and a custom training loop. Calling loss.backward() on a loss tensor triggers backpropagation. Once that's done, the optimizer is magically aware of the gradients for each variable and can update its variables, which is done via optimizer.step(). Tensors, variables, and optimizers are all interconnected via hidden global state. Also, don't forget to call model.zero_grad() before loss.backward(), or you won't get the right gradients for your variables.

Here are the steps of our training loop:

  • We open a for loop that iterates over epochs
  • For each epoch, we open a for loop that iterates over the dataset, in batches
  • For each batch, we call the model on the input data to retrieve the predictions, then use them to compute a loss value
  • We call loss.backward() to compute the gradients of the model weights with regard to the loss
  • Finally, we use the optimizer to update the weights of the model based on the gradients (via optimizer.step())
epochs = 3
for epoch in range(epochs):
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(logits, targets)

        # Backward pass
        model.zero_grad()
        loss.backward()

        # Optimizer variable updates
        optimizer.step()

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

When we use a Keras optimizer and a Keras loss function instead, the training loop looks a bit different. Here are the important differences:

  1. You retrieve the gradients of the variables by calling v.value.grad on each trainable variable (v.value is the underlying torch tensor of the Keras variable, so this replaces PyTorch's usual v.grad).

  2. You update the variables via optimizer.apply() (the Keras optimizer API, rather than the native PyTorch optimizer.step()), and this must be done inside a torch.no_grad() scope.

One important gotcha: while all NumPy/TensorFlow/JAX/Keras APIs, as well as Python's unittest API, use the argument order convention fn(y_true, y_pred) (reference values first, predicted values second), PyTorch actually uses fn(y_pred, y_true) for its losses. So make sure to invert the order of the logits and targets when switching between them.
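Why the order matters can be seen with a small hand-rolled cross-entropy (a hypothetical illustration in pure Python, not the Keras or torch implementation): because the loss is asymmetric in its two arguments, swapping them yields a very different value:

```python
import math

def crossentropy(y_true, y_pred, eps=1e-7):
    # Clip predictions away from 0 and 1 to keep log() finite.
    return -sum(
        t * math.log(min(max(p, eps), 1.0 - eps))
        for t, p in zip(y_true, y_pred)
    )

y_true = [0.0, 1.0]   # one-hot target
y_pred = [0.3, 0.7]   # predicted probabilities

correct = crossentropy(y_true, y_pred)   # ~0.357
swapped = crossentropy(y_pred, y_true)   # much larger: log() hits the clip
print(round(correct, 3), round(swapped, 3))
```

A symmetric loss like MSE would hide this bug, which is why it often goes unnoticed until switching loss functions.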

model = get_model()
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply(gradients, trainable_weights)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

3.3.3 Low-level handling of metrics in PyTorch

To monitor what's happening as the loop runs, you can add metrics monitoring to this basic training loop.
You can readily reuse built-in Keras metrics (or custom ones you have written) in such a from-scratch training loop. Here's the flow:

  • Instantiate the metric at the start of the loop
  • Call metric.update_state() after each batch
  • Call metric.result() when you need to display the current value of the metric
  • Call metric.reset_state() when you need to clear the state of the metric (typically at the end of an epoch)

The following example uses this knowledge to compute CategoricalAccuracy on training and validation data at the end of each epoch:

# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()

Here are the training and evaluation loops:

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply(gradients, trainable_weights)

        # Update training metric.
        train_acc_metric.update_state(targets, logits)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataloader:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")

3.3.4 Low-level handling of losses tracked by the model

Layers and models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values is available via the model.losses property at the end of the forward pass.

If you want to be using these loss components, you should sum them and add them to the main loss in your training step.

Consider this layer, which creates an activity regularization loss:

class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * torch.sum(inputs))
        return inputs

Build a model that uses the regularization loss:

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

With this, our training and evaluation loop becomes:

# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)
        if model.losses:
            # Sum the scalar regularization losses and add them to the main loss.
            loss = loss + sum(model.losses)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply(gradients, trainable_weights)

        # Update training metric.
        train_acc_metric.update_state(targets, logits)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataloader:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")

4. Summary

Writing custom training and evaluation loops is an important skill in deep learning: it gives researchers and developers finer-grained control over the training process. Here is a summary of what custom loops offer:

4.1 Flexibility

  • Control granularity: custom loops let you control the training process in fine detail, including per-epoch and per-batch handling as well as data loading and preprocessing.
  • Custom loss functions: you can easily implement and integrate custom loss functions tailored to the task at hand.
  • Optimizer choice: you can pick the optimizer best suited to your task and tune its parameters.

4.2 Extensibility

  • Additional metrics: beyond the usual loss and accuracy, you can add custom monitoring metrics to evaluate model performance.
  • Extra features: you can integrate additional functionality such as model checkpointing, learning rate scheduling, and early stopping.

4.3 Debugging and visualization

  • Easier debugging: because the code of a custom loop is explicit and straightforward, it is easier to debug and trace errors.
  • Visualization: you can use tools such as TensorBoard to visualize the training process, including loss and accuracy curves.

4.4 Performance optimization

  • Memory management: custom loops allow finer-grained memory management, such as manually releasing variables that are no longer needed.
  • Parallel and distributed training: when you need to process large amounts of data or speed up training, parallel and distributed training techniques are often easier to implement inside a custom loop.

4.5 Things to watch out for

  • Data loading: make sure data is loaded correctly and that preprocessing matches the model's inputs.
  • Gradient accumulation: when GPU memory is limited, gradient accumulation can be used to simulate a larger batch size.
  • Learning rate scheduling: adjust the learning rate based on training progress to improve results.
  • Error handling: add error handling to the loop so that models and resources are shut down cleanly when something goes wrong.
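Of these, gradient accumulation is the most mechanical to add to a custom loop. Here is a framework-agnostic sketch in plain Python (the analytic gradient and the numbers are hypothetical stand-ins for a real backward pass):

```python
# Gradient accumulation: sum gradients over several micro-batches,
# then apply a single averaged update, simulating a larger batch size.
accum_steps = 4
lr = 0.1
params = [1.0]
accumulated = [0.0]

def compute_grads(params, batch):
    # Stand-in for a real backward pass: gradient of 0.5 * (p - batch)^2.
    return [p - batch for p in params]

batches = [0.0, 2.0, 4.0, 6.0, 1.0, 3.0, 5.0, 7.0]
for step, batch in enumerate(batches, start=1):
    grads = compute_grads(params, batch)
    accumulated = [a + g for a, g in zip(accumulated, grads)]
    if step % accum_steps == 0:
        # Apply the averaged gradient, then reset the accumulator.
        params = [p - lr * a / accum_steps for p, a in zip(params, accumulated)]
        accumulated = [0.0 for _ in accumulated]
print(params)  # two updates applied, one per group of 4 micro-batches
```

In a real loop, the accumulator would hold per-weight gradient tensors and the update would go through the optimizer rather than plain subtraction.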

Writing custom training and evaluation loops is a challenging task, but it brings great flexibility, extensibility, and opportunities for performance optimization. By understanding the training process in depth and writing your own loops, you gain fine-grained control over model performance and can explore new optimization strategies.
