转载https://zhuanlan.zhihu.com/p/86886620
在tensorflow 2.x中自定义网络层最好的方法就是继承tf.keras.layers.Layer类,并且根据实际情况重写如下几个类方法:
__init__:
初始化类,你可以在此配置一些网络层需要的参数,并且也可以在此实例化tf.keras提供的一些基础算子比如DepthwiseConv2D,方便在call方法中应用;你可以在其中执行所有与输入无关的初始化。
build:
可以获得输入张量的形状,并可以进行其余的初始化。该方法可以获取输入的shape维度,方便动态构建某些需要的算子比如Pool或基于input shape构建权重;
call:
网络层进行前向推理的实现方法;构建网络结构,进行前向传播。
一般常见的自定义网络层如下,其中build方法不是必需的,大部分情况下都可以省略:
class MyLayer(tf.keras.layers.Layer):
def __init__(self, num_outputs, name="MyLayer", **kwargs):
super(MyLayer, self).__init__(name=name, **kwargs)
self.num_outputs = num_outputs
def build(self, input_shape):
self.kernel = self.add_variable("kernel", shape=[int(input_shape[-1]), self.num_outputs])
super().build(input_shape)
def call(self, input):
output = tf.matmul(input, self.kernel)
return output
其中,当我们在初始化该自定义网络时,需要明确告知num_outputs参数,而其他参数则不是必须的,另外,当该自定义网络组成的模型开始构建后,其就会去调用build方法去构建一个self.kernel权重,而当模型进行前向推理时,则会调用call方法进行计算!add_variable
将被替换为add_weight
方法。
调用自定义网络层的方法如下:
mylayer = MyLayer(num_outputs=1000, name="MyLayer", training=True)
三、注意事项
理论上来说,一般我们只要继承自tf.keras.layers.Layer并且实现好* init 以及 * call 两个接口就可以了,但是当我们在使用model.save()方法保存整个模型架构及其权重为hdf5格式时,我们却会遇到如下问题:
NotImplementedError: Layers with arguments in `__init__` must override `get_config`.
## 还有关于保存和加载模型的问题没有粘贴过来。