Keras深度学习框架第五讲:在JAX中自定义fit()函数中的操作

80 篇文章 0 订阅
50 篇文章 4 订阅

##1、 绪论

在进行有监督学习时,使用fit()函数通常可以顺利地进行。

但当需要控制每一个小细节时,可以完全从头开始编写自己的训练循环。

但如果需要一个自定义的训练算法,同时又想利用fit()的便捷功能,如回调(callbacks)、内置的分布支持(built-in distribution support)或步骤融合(step fusing)等,又该如何呢?

Keras的一个核心原则是复杂性的逐步揭示,总是能够以一种渐进的方式进入更低级别的工作流程。如果高级功能不完全符合你的测试用例,也不应该一下子陷入困境。而应该能够在获得更多对小细节的控制权的同时,保留相应级别的高级便利。

当需要自定义fit()函数的行为时,应该重写Model类的训练步骤函数。这个函数在fit()为每一批数据调用时执行。之后,就可以像往常一样调用fit()——而它将会运行重新编写后的包含作者思想的学习算法。

请注意,这种模式并不阻止编程者使用函数式API构建模型。无论是构建Sequential模型、函数式API模型还是子类化模型,都可以这样做。

2、准备工作

2.1 设置

按照以下的代码试样进行初始化设置

import os

# This guide can only be run with the JAX backend.
os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras
import numpy as np

#2.2 使用的示例

这一节我们以一个简单的例子开始感受自定义fit()函数的方法:

首先创建一个新的类,该类继承自keras.Model
这样就实现了一个完全无状态的compute_loss_and_updates()方法,用于计算损失以及模型非可训练变量的更新值。在内部,它调用了stateless_call()和内置的compute_loss()
同时还实现了一个完全无状态的train_step()方法,用于计算当前指标值(包括损失)以及可训练变量、优化器变量和指标变量的更新值。
请注意,你还可以通过以下方式考虑sample_weight参数:

  • 解包数据为x, y, sample_weight = data
  • sample_weight传递给compute_loss()
  • stateless_update_state()中将sample_weightyy_pred一起传递给指标
class CustomModel(keras.Model):
    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss = self.compute_loss(x, y, y_pred)
        return loss, (y_pred, non_trainable_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y = data

        # Get the gradient function.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

        # Compute the gradients.
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            x,
            y,
            training=True,
        )

        # Update trainable variables and optimizer variables.
        (
            trainable_variables,
            optimizer_variables,
        ) = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Update metrics.
        new_metrics_vars = []
        logs = {}
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred
                )
            logs[metric.name] = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state

通过调用运行代码,看看运行的结果:

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
Epoch 1/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - mae: 1.0022 - loss: 1.2464
Epoch 2/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 198us/step - mae: 0.5811 - loss: 0.4912
Epoch 3/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 231us/step - mae: 0.4386 - loss: 0.2905

<keras.src.callbacks.history.History at 0x14da599c0>

3、深入了解自定义fit()的内部运行机制

#3.1 train_step传递损失

当然,你可以在compile()中省略传递损失函数,而是在train_step中手动完成所有操作。对于指标也是如此。

下面是一个更低层次的示例,它只使用compile()来配置优化器:

class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss = self.loss_fn(y, y_pred)
        return loss, (y_pred, non_trainable_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y = data

        # Get the gradient function.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

        # Compute the gradients.
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            x,
            y,
            training=True,
        )

        # Update trainable variables and optimizer variables.
        (
            trainable_variables,
            optimizer_variables,
        ) = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Update metrics.
        loss_tracker_vars = metrics_variables[: len(self.loss_tracker.variables)]
        mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]

        loss_tracker_vars = self.loss_tracker.stateless_update_state(
            loss_tracker_vars, loss
        )
        mae_metric_vars = self.mae_metric.stateless_update_state(
            mae_metric_vars, y, y_pred
        )

        logs = {}
        logs[self.loss_tracker.name] = self.loss_tracker.stateless_result(
            loss_tracker_vars
        )
        logs[self.mae_metric.name] = self.mae_metric.stateless_result(mae_metric_vars)

        new_metrics_vars = loss_tracker_vars + mae_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [self.loss_tracker, self.mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
Epoch 1/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.6085 - mae: 0.6580
Epoch 2/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 215us/step - loss: 0.2630 - mae: 0.4141
Epoch 3/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 202us/step - loss: 0.2271 - mae: 0.3835
Epoch 4/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 192us/step - loss: 0.2093 - mae: 0.3714
Epoch 5/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 194us/step - loss: 0.2188 - mae: 0.3818

<keras.src.callbacks.history.History at 0x14de01420>

#3.2 重写test_step

如果你想要在调用model.evaluate()时做类似的事情,那么你需要以完全相同的方式重写test_step。它看起来像这样:

class CustomModel(keras.Model):
    def test_step(self, state, data):
        # Unpack the data.
        x, y = data
        (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = state

        # Compute predictions and loss.
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=False,
        )
        loss = self.compute_loss(x, y, y_pred)

        # Update metrics.
        new_metrics_vars = []
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred
                )
            logs = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            new_metrics_vars,
        )
        return logs, state


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 973us/step - mae: 0.7887 - loss: 0.8385

[0.8385222554206848, 0.7956181168556213]

4、总结

在JAX中自定义类似于Keras的fit()函数的行为,你需要从头开始实现整个训练循环,因为JAX并不像Keras那样提供高层级的训练和评估API。JAX主要是一个用于自动微分和硬件加速的库,它鼓励你以更函数式的方式编写代码,而不是使用像Keras那样的面向对象的方法。

要在JAX中自定义类似fit()函数的行为,你需要:

  1. 定义模型:首先,你需要定义你的模型架构。在JAX中,这通常是通过使用JAX的jit(即时编译)和vmap(向量映射)等工具,结合NumPy-like的API来完成的。模型可能是一组函数,用于执行前向传递和计算损失。

  2. 定义损失函数:接下来,你需要定义一个损失函数,该函数接受模型预测和真实标签作为输入,并返回一个标量损失值。

  3. 定义优化器:在JAX中,你需要选择或实现一个优化器。这通常是一个函数,它接受梯度、参数和可选的学习率,并返回更新后的参数。

  4. 编写训练循环:现在,你可以编写一个训练循环。这个循环会重复以下步骤:

    a. 对一批数据进行前向传递,得到预测值。

    b. 计算损失。

    c. 使用JAX的自动微分工具(如grad)计算梯度。

    d. 使用优化器更新模型参数。

    e. 跟踪训练指标(如损失、准确率等)。

    f. 在验证集上评估模型(如果需要)。

    g. 根据需要保存模型检查点。

  5. 处理数据:你还需要一个数据加载和批处理机制。这可以是一个简单的循环,用于迭代数据集中的每个样本或批次,也可以是一个更复杂的数据加载器,它使用诸如jax.process_mapjax.pmap等工具进行并行处理。

  6. 日志记录和监控:你可能还想记录训练过程中的指标,并使用诸如TensorBoard或其他可视化工具进行监控。在JAX中,你需要自己实现这些功能,或者使用与JAX兼容的第三方库。

需要注意的是,JAX的设计哲学是鼓励用户以更底层、更灵活的方式编写代码。因此,虽然这可能会增加一些初始的复杂性,但它也提供了更大的灵活性和控制权。在JAX中自定义fit()函数的行为是一个很好的机会,可以深入了解深度学习训练的底层机制,并根据你的具体需求进行定制。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MUKAMO

你的鼓励是我们创作最大的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值