方法调用unknown_Tensorflow 2.x自定义网络层的方法及其注意事项

一、概述

在TensorFlow 2.0 中,tf.keras 是推荐使用的默认高级 API,这就为我们使用统一的API构建搞效的神经网络带来极大的便捷,但是tensorflow 2.0是默认你就是懂用keras的,对于keras不是很熟悉的人难免会踩到一些坑。这不,我就在开发MobileNetV3的过程中遇到了不少问题,好在都被我一一解决了!其中,在现有的各种纷繁复杂模型算法里已不再仅仅调用基础的API就可以构建了,有时我们想像搭积木一样构建一个更清晰抽象的网络层,有时我们需要根据输入动态的构建权重,有时我们需要在特定网络层内做一些复杂的运算,有时我们也希望把一些统一的算子封装成一个具体的网络层比如BottleNet,这时候我们就会需要进行自定义网络层了,那么我们该怎么使用tf.keras自定义网络层及有哪些注意事项呢?本篇文章,我想分享一下tensorflow 2.x自定义网络层的方法及其注意事项!

二、方法

在tensorflow 2.x中自定义网络层最好的方法就是继承tf.keras.layers.Layer类,并且根据实际情况重写如下几个类方法:

*__init__:初始化类,你可以在此配置一些网络层需要的参数,并且也可以在此实例化tf.keras提供的一些基础算子比如DepthwiseConv2D,方便在call方法中应用;

*build:该方法可以获取输入的shape维度,方便动态构建某些需要的算子比如Pool或基于input shape构建权重;

*call: 网络层进行前向推理的实现方法;

一般常见的自定义网络层如下,其中build方法不是必需的,大部分情况下都可以省略:

class MyLayer(tf.keras.layers.Layer):
    def __init__(self, num_outputs, name="MyLayer", **kwargs):
        super(MyLayer, self).__init__(name=name, **kwargs)
        self.num_outputs = num_outputs

    def build(self, input_shape):
        self.kernel = self.add_variable("kernel", shape=[int(input_shape[-1]), self.num_outputs])
        super().build(input_shape)

    def call(self, input):
        output = tf.matmul(input, self.kernel)
        return output

其中,当我们在初始化该自定义网络时,需要明确告知num_outputs参数,而其他参数则不是必须的,另外,当该自定义网络组成的模型开始构建后,其就会去调用build方法去构建一个self.kernel权重,而当模型进行前向推理时,则会调用call方法进行计算!

调用自定义网络层的方法如下:

mylayer = MyLayer(num_outputs=1000, name="MyLayer", training=True)

三、注意事项

注意一:

理论上来说,一般我们只要继承自tf.keras.layers.Layer并且实现好* __init__ 以及 * call 两个接口就可以了,但是当我们在使用model.save()方法保存整个模型架构及其权重为hdf5格式时,我们却会遇到如下问题:

NotImplementedError: Layers with arguments in `__init__` must override `get_config`.

这个问题是因为我们没有在自定义网络层时重写get_config导致的!那我们该怎么去实现该方法呢?

我们主要看传入__init__接口时有哪些配置参数,然后在get_config内一一的将它们转为字典键值并且返回使用,以Mylayer为例:

def get_config(self):
    config = {"num_outputs":self.num_outputs}
    base_config = super(Mylayer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

一般来说,父类的config也是需要一并保存的,其中base_config即是父类网络层实现的配置参数,最后把父类及继承类的config组装为字典形式即可解决该问题!

注意二:

当我们自定义网络层并且有效保存模型后,希望使用tf.keras.models.load_model进行模型加载时,可能会报如下错误:

raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
ValueError: Unknown layer: Mylayer

解决该问题的方法如下:

首先,建立一个字典,该字典的键是自定义网络层时设定该层的名字,其值为该自定义网络层的类名,该字典将用于加载模型时使用!

_custom_objects = {
    "Mylayer" :  Mylayer,
}

然后,在tf.keras.models.load_model内传入custom_objects告知如何解析重建自定义网络层,其方法如下:

model = tf.keras.models.load_model("path/to/your/model", custom_objects=_custom_objects)

注意三:

当我们自定义一个网络层其名字与默认的tf.keras网络层一样时,可能会报出一些奇怪的问题,其实是因为重名了,比如当我们定义如下网络层时:

class Dropout(tf.keras.layers.Layer):
    def __init__(self, dropout_rate, name="Dropout", **kwargs):
        super(Dropout, self).__init__(name=name, **kwargs)
        self.dropout_rate = dropout_rate
        self.dropout = tf.keras.layers.Dropout(rate=dropout_rate, name=f'Dropout', **kwargs)

    def call(self, input):
        output = self.dropout(input)
        return output

    def get_config(self):
        config = {"dropout_rate":self.dropout_rate}
        base_config = super(Dropout, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

因为,默认的tf.keras内建网络层已经有了一个同名的tf.keras.layers.Dropout实现,名字都为“Dropout”,当我们在保存模型时没有问题,但当我们加载模型时就会报如下错误:

TypeError: __init__() missing 1 required positional argument: 'rate'

所以,只要我们换用另一个自定义网络层名字即可解决该问题,而如果不注意这一点,那么死活也找不到原因,我也是花了很长时间才发现了这个问题!

注意四:

我们在实现自定义网络层时,最好统一在初始化时传入可变参数**kwargs,这是因为在model推理时,有时我们需要对所有构成该模型的网络层进行统一的传参,比如传入training标志位用以告知所有网络层当前所属的计算状态,如果某个自定义网络层未有传入**kwargs并且传给其使用到的一些基础组件,那么就会产生一些意想不到的问题,比如存在执行logit = model(input, training=True)前向推理时,保存模型无误,但是加载模型时则会报如下错误!当然避免该问题的方法就是不要传入training参数,但是这个training有时又是不可缺少的,他可以保证模型正常的训练并且进行测试推理,这时候就需要我们注意了!

TypeError: __init__() got an unexpected keyword argument 'trainable'

彩蛋:

一个高效简捷基于tensorflow 2.x实现的MobileNetV3已实作测试好,欢迎给个Star鼓励!

Many thanks!

sirius-ai/MobileNetV3-TF​github.com
c7746eb054a996ff64771f1057c8b086.png
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值