fit函数 model_TF2 Keras (7) : 定制化 Model.fit

本文是对官方文档 的学习笔记。

首先说了一下 Keras 的一个设计理念: 提供阶梯型的复杂性。 一开始新手接触的是封装的很好的接口, 好处是简单, 缺点是缺乏灵活性, 无法发挥自己的想象力。 当开发者需要修改底层机制的时候, Keras 争取让开发者逐渐的面对low level 代码的复杂性, 而不是一下子将所有的细节都暴露给开发者(然后把他们吓跑)。

A core principle of Keras is progressive disclosure of complexity. You should always be able to get into lower-level workflows in a gradual way. You shouldn't fall off a cliff if the high-level functionality doesn't exactly match your use case. You should be able to gain more control over the small details while retaining a commensurate amount of high-level convenience.

如果开发者想控制 fit 过程, 那么他们可以通过 override Model 类的 fit 函数来实现。

第一个例子

这里例子中, 重载了train_step 函数, 在该函数中使用 self.compiled_loss 来计算 loss. 然后 使用 self.compiled_metrics.update_state(y, y_pred) 来更新 metrics。

class CustomModel(keras.Model):

def train_step(self, data):

# Unpack the data. Its structure depends on your model and

# on what you pass to `fit()`.

x, y = data

with tf.GradientTape() as tape:

y_pred = self(x, training=True) # Forward pass

# Compute the loss value

# (the loss function is configured in `compile()`)

loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

# Compute gradients

trainable_vars = self.trainable_variables

gradients = tape.gradient(loss, trainable_vars)

# Update weights

self.optimizer.apply_gradients(zip(gradients, trainable_vars))

# Update metrics (includes the metric that tracks the loss)

self.compiled_metrics.update_state(y, y_pred)

# Return a dict mapping metric names to current value

return {m.name: m.result() for m in self.metrics}

使用这个Model

import numpy as np

# Construct and compile an instance of CustomModel

inputs = keras.Input(shape=(32,))

outputs = keras.layers.Dense(1)(inputs)

model = CustomModel(inputs, outputs)

model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual

x = np.random.random((1000, 32))

y = np.random.random((1000, 1))

model.fit(x, y, epochs=3)

更多的细节

其实也可以不再 Compile 中填写 loss Fucntion, 而选择吧所有事情都放在train_step 中来做。

下面是一个更加 low level的例子, Co

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值