These are my study notes on the official Keras documentation.
First, a core Keras design principle: progressive disclosure of complexity. Newcomers start with well-encapsulated high-level APIs, which are simple to use but offer limited flexibility and leave little room for creativity. When developers need to modify the underlying mechanics, Keras aims to expose the complexity of the low-level code gradually, rather than dumping every detail on them at once (and scaring them away).
A core principle of Keras is progressive disclosure of complexity. You should always be able to get into lower-level workflows in a gradual way. You shouldn't fall off a cliff if the high-level functionality doesn't exactly match your use case. You should be able to gain more control over the small details while retaining a commensurate amount of high-level convenience.
If developers want to control what happens during fit(), they can subclass keras.Model and override its train_step method; fit() itself can then be called as usual.
First example
In this example, train_step is overridden. Inside it, self.compiled_loss computes the loss (using the loss function configured in compile()), and self.compiled_metrics.update_state(y, y_pred) updates the metrics.
import tensorflow as tf
from tensorflow import keras

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
Using this model:
import numpy as np
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
More details
In fact, you can also skip passing a loss function to compile() and instead handle everything yourself inside train_step.
Below is a lower-level example.
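The original notes break off here, so the following is a minimal sketch of the lower-level pattern: compute the loss manually inside train_step, track loss and metrics with your own metric objects, and expose them through the metrics property so Keras can reset them at the start of each epoch. The names loss_tracker and mae_metric are illustrative, not part of the Keras API.

import numpy as np
import tensorflow as tf
from tensorflow import keras

class LowLevelModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Our own trackers, since compile() no longer manages loss/metrics.
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute our own loss instead of relying on compile(loss=...)
            loss = keras.losses.mean_squared_error(y, y_pred)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update our own trackers and report their current values
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {"loss": self.loss_tracker.result(), "mae": self.mae_metric.result()}

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them automatically
        # at the start of each epoch.
        return [self.loss_tracker, self.mae_metric]

# Usage: note that compile() receives only the optimizer.
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = LowLevelModel(inputs, outputs)
model.compile(optimizer="adam")

x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)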