Context:
This applies when saving a model in HDF5 format.
Motivation:
- If `get_config` is not overridden, the model graph cannot be loaded in TensorBoard.
- The model cannot be saved with `model.save`.
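In tf.keras (TF 2.x), the base-class `get_config` refuses to run if the subclass initializer declares arguments of its own (such as `units`) and `get_config` has not been overridden: it raises `NotImplementedError` instead of returning a config.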
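To illustrate why the override is needed, here is a toy, pure-Python sketch of that check. `BaseLayer`, `NoConfig` and `WithConfig` are hypothetical names; the sketch only roughly mirrors the idea of the tf.keras check (inspect the subclass `__init__`, refuse if it has arguments the default config cannot record) and is not the real Keras source.

```python
import inspect


class BaseLayer:
    """Toy stand-in for keras.layers.Layer (NOT the real Keras code)."""

    def __init__(self, name=None, trainable=True):
        self.name = name or type(self).__name__.lower()
        self.trainable = trainable

    def get_config(self):
        config = {"name": self.name, "trainable": self.trainable}
        # Named arguments of the subclass __init__, excluding "self".
        init_args = inspect.getfullargspec(type(self).__init__).args[1:]
        extra = [a for a in init_args if a not in config]
        # Refuse if the subclass added its own arguments but still relies
        # on this default implementation, which cannot know how to store them.
        if extra and type(self).get_config is BaseLayer.get_config:
            raise NotImplementedError(
                f"Layer {type(self).__name__} has arguments {extra} in "
                "`__init__` and therefore must override `get_config`."
            )
        return config

    @classmethod
    def from_config(cls, config):
        # Rebuild the layer by passing the config back to the constructor.
        return cls(**config)


class NoConfig(BaseLayer):
    """Has its own `units` argument but does not override get_config."""

    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units


class WithConfig(BaseLayer):
    """Same constructor, but records `units` in its config."""

    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units})
        return config


try:
    NoConfig(units=64).get_config()
except NotImplementedError as err:
    print("refused:", err)

cfg = WithConfig(units=64).get_config()
print(cfg)  # {'name': 'withconfig', 'trainable': True, 'units': 64}
print(WithConfig.from_config(cfg).units)  # 64
```

The overridden version works because the base check only fires while the default `get_config` is still in use; once the subclass adds its own arguments to the dictionary, `from_config` has everything it needs to call the constructor again.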
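Approach:
In the custom layer, override `get_config` and add the constructor arguments to the config dictionary returned by the base class.
Example: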
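```python
import tensorflow as tf
from tensorflow import keras


class Linear(keras.layers.Layer):
    """A dense layer without activation: y = x @ w + b."""

    def __init__(self, units=32, **kwargs):
        # Forward base-class arguments (name, dtype, trainable, ...).
        super(Linear, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Weights are created lazily, once the input size is known.
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        # Start from the base config (name, trainable, dtype, ...)
        # and add this layer's own constructor argument.
        config = super(Linear, self).get_config()
        config.update({"units": self.units})
        return config


layer = Linear(64)
config = layer.get_config()
print(config)
# The default from_config passes the config back to the constructor,
# creating an equivalent (not yet built) layer.
new_layer = Linear.from_config(config)
```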
References: