Tensorflow 2.x自定义网络层的方法及其注意事项

转载https://zhuanlan.zhihu.com/p/86886620

在tensorflow 2.x中自定义网络层最好的方法就是继承tf.keras.layers.Layer类,并且根据实际情况重写如下几个类方法:
__init__:初始化类,你可以在此配置一些网络层需要的参数,并且也可以在此实例化tf.keras提供的一些基础算子比如DepthwiseConv2D,方便在call方法中应用;你可以在其中执行所有与输入无关的初始化。
build:可以获得输入张量的形状,并可以进行其余的初始化。该方法可以获取输入的shape维度,方便动态构建某些需要的算子比如Pool或基于input shape构建权重;
call: 网络层进行前向推理的实现方法;构建网络结构,进行前向传播。
一般常见的自定义网络层如下,其中build方法不是必需的,大部分情况下都可以省略:

class MyLayer(tf.keras.layers.Layer):
    def __init__(self, num_outputs, name="MyLayer", **kwargs):
        super(MyLayer, self).__init__(name=name, **kwargs)
        self.num_outputs = num_outputs

    def build(self, input_shape):
        self.kernel = self.add_variable("kernel", shape=[int(input_shape[-1]), self.num_outputs])
        super().build(input_shape)

    def call(self, input):
        output = tf.matmul(input, self.kernel)
        return output

其中,当我们在初始化该自定义网络时,需要明确告知num_outputs参数,而其他参数则不是必须的,另外,当该自定义网络组成的模型开始构建后,其就会去调用build方法去构建一个self.kernel权重,而当模型进行前向推理时,则会调用call方法进行计算!add_variable将被替换为add_weight方法。
调用自定义网络层的方法如下:

mylayer = MyLayer(num_outputs=1000, name="MyLayer", training=True)

三、注意事项
理论上来说,一般我们只要继承自tf.keras.layers.Layer并且实现好* init 以及 * call 两个接口就可以了,但是当我们在使用model.save()方法保存整个模型架构及其权重为hdf5格式时,我们却会遇到如下问题:

NotImplementedError: Layers with arguments in `__init__` must override `get_config`.

## 还有关于保存和加载模型的问题没有粘贴过来。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值