keras实现自定义层的关键步骤解析

1537250305768859.png

1537520241840992.png 

前言:Keras提供众多常见的已编写好的层对象,例如常见的卷积层、池化层等,我们可以直接通过以下代码调用。Keras中的层大致上分为两种类型:

第一种是带有训练参数的:比如Dense层、Conv2D层,等等,我们在训练的过程中需要训练层的权重和偏置项;

第二种是不带训练参数的:比如dropout层、flatten层、等等,我们不需要训练它的权重,只需要对输入进行加工处理再输出就行了。

我们在实际应用中,我们经常需要自己构建一些层对象,已满足某些自定义网络的特殊需求。,也无非就是上面两种,一种是带有参数的,一种是不带参数的,不管是哪一种,幸运的是,Keras对自定义层都提供了良好的支持。

一、Lambda层

二、自定义层

    2.1 Dense层解析

    2.2 三个核心方法的解析

        (1)build方法

        (2)call方法

        (3)compute_output_shape方法

    2.3 基类Layer中的定义

一、Lambda层

对于简单、无状态的自定义操作,你也许可以通过 layers.core.Lambda 层来实现。即使用keras.core.lambda()
如果我们的自定义层中不包含可训练的权重,而只是对上一层输出做一些函数变换,那么我们可以直接使用keras.core模块(该模块包含常见的基础层,如Dense、Activation等)下的lambda函数:

keras.layers.core.Lambda(function, output_shape=None, mask=None, arguments=None)

参数说明: 
function:要实现的函数,该函数仅接受一个变量,即上一层的输出 
output_shape:函数应该返回的值的shape,可以是一个tuple,也可以是一个根据输入shape计算输出shape的函数 
mask: 掩膜 
arguments:可选,字典,用来记录向函数中传递的其他关键字参数

注意:也不是说对于没有训练参数的层就一定要用Lambda层,我也可以使用自定义的层,只不过是没必要那么复杂而已。

二、自定义层

但是对于那些包含了可训练权重的自定义层,你应该自己实现这种层。在这种情况下,我们需要定义的是一个全新的、拥有可训练权重的层,这个时候我们就需要使用下面的方法。即通过编写自定义层,从Layer中继承。

下面的内容来自官方文档,有部分是自己的注解,后面着重总结的

这是一个 Keras2.0 中,Keras 层的骨架(如果你用的是旧的版本,请更新到新版)。你只需要实现三个方法即可:

要定制自己的层,需要实现下面三个方法

  • build(input_shape):这是定义权重的方法,可训练的权应该在这里被加入列表self.trainable_weights中。其他的属性还包括self.non_trainabe_weights(列表)和self.updates(需要更新的形如(tensor,new_tensor)的tuple的列表)。这个方法必须设置self.built = True,可通过调用super([layer],self).build()实现。
  • call(x):这是定义层功能的方法,除非你希望你写的层支持masking,否则你只需要关心call的第一个参数:输入张量。
  • compute_output_shape(input_shape):如果你的层修改了输入数据的shape,你应该在这里指定shape变化的方法,这个函数使得Keras可以做自动shape推断。
     
from keras import backend as K
from keras.engine.topology import Layer

class MyLayer(Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # 为该层创建一个可训练的权重
        self.kernel = self.add_weight(name='kernel', 
                                      shape=(input_shape[1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)  # 一定要在最后调用它

    def call(self, x):
        return K.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

还可以定义具有多个输入张量和多个输出张量的 Keras 层。 为此,你应该假设方法 build(input_shape)call(x) 和 compute_output_shape(input_shape) 的输入输出都是列表。 这里是一个例子,与上面那个相似:

from keras import backend as K
from keras.engine.topology import Layer

class MyLayer(Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        assert isinstance(input_shape, list)
        # 为该层创建一个可训练的权重
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[0][1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)  # 一定要在最后调用它

    def call(self, x):
        assert isinstance(x, list)
        a, b = x
        return [K.dot(a, self.kernel) + b, K.mean(b, axis=-1)]

    def compute_output_shape(self, input_shape):
        assert isinstance(input_shape, list)
        shape_a, shape_b = input_shape
        return [(shape_a[0], self.output_dim), shape_b[:-1]]

上面的文字都是来自文档,这里没有给出具体的实现功能,我下面以keras自己实显得Dense层为例来加以说明:

2.1 Dense层解析

from .. import backend as K
from .. import activations
from .. import initializers
from .. import regularizers
from .. import constraints
from ..engine.base_layer import InputSpec
from ..engine.base_layer import Layer
from ..utils.generic_utils import func_dump
from ..utils.generic_utils import func_load
from ..utils.generic_utils import deserialize_keras_object
from ..utils.generic_utils import has_arg
from ..utils import conv_utils
from ..legacy import interfaces

class Dense(Layer):
   """
   继承自Layer层,这里的大片说明文档我省略了,因为它的功能我们都知道,我们重点看功能是怎么实现的
   """
   """Dense类的构造函数,一般包涵kernel,bias的initializer初始化、regularizer正则化、约束项constraint等内容。
注意:下面的activations、initializers、regularizers、constraints等都是在来自于上面import导入的模块哦
   """
    @interfaces.legacy_dense_support
    def __init__(self, units,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(Dense, self).__init__(**kwargs) # 调用父类的构造函数
        self.units = units
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True

    """
    这是第一个必须要实现的方法————在这里定义权重以及偏置项
    这里需要注意的是 self.add_weight()这个方法是哪里来的,他其实是定义在基类Layer里面的一个方法,后面会说到。一般的格式为如下:
    self.weight=self.add_weight(self,name,shape,dtype=None,initializer=None,regularizer=None,trainable=True,constraint=None)
当然这里的initializer、regularizer、trainable、constraint这几个参数一般都是来自构造函数。
    """
    def build(self, input_shape):
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]

        self.kernel = self.add_weight(shape=(input_dim, self.units),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.units,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})

        self.built = True  # 最后这句话一定要加上

    """
    call是这个层的功能实现的地方,他接受的参数是inputs,返回output,这是搭建model的关键所在,在这个函数里面实现这个层的运算
    """
    def call(self, inputs):
        output = K.dot(inputs, self.kernel)
        if self.use_bias:
            output = K.bias_add(output, self.bias, data_format='channels_last')
        if self.activation is not None:
            output = self.activation(output)
        return output
    """
    这个函数是告诉我经过运算之后热输出的形状
    """
    def compute_output_shape(self, input_shape):
        assert input_shape and len(input_shape) >= 2
        assert input_shape[-1]
        output_shape = list(input_shape)
        output_shape[-1] = self.units
        return tuple(output_shape)
    
    """ 
    这个是方面获取这个层的配置信息的。可以不要。
    """
    def get_config(self):
        config = {
            'units': self.units,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer':
                regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint)
        }
        base_config = super(Dense, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

2.2 三个核心方法的解析

(1)build方法——定义这个曾“固有的”权重信息(记者这里面定义kernel,bias信息),定义这个方法时候的几个注意点

第一:函数原型为 def build(self, input_shape): 及接受input_shape参数

第二:权重的定义方式。self.weight=self.add_weight(... ...)

"""
权重定义的一般的格式为如下:
    self.weight=self.add_weight(self,name,shape,dtype=None,initializer=None,regularizer=None,trainable=True,constraint=None)
当然这里的initializer、regularizer、trainable、constraint这几个参数一般都是来自构造函数。
"""
def build(self, input_shape):
    self.kernel = self.add_weight(shape=(input_dim, self.units),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

    self.bias = self.add_weight(shape=(self.units,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
    self.built = True  # 最后这句话一定要加上

第三:最后一定要加一句话

self.built = True  或者是
super(MyLayer, self).build(input_shape)  # 一定要在最后调用它

 (2)call方法——这是“层功能”的实现方法,定义这个方法时候的几个注意点

他接受的参数为inputs,输出为output,这是我们后面可以通过Dense(... ...)(x) 这样操作的关键。定义如下:

def call(self, inputs):
    ... ...
    return output

(3)compute_output_shape方法——如果你的层更改了输入张量的形状,你应该在这里定义形状变化的逻辑,这让Keras能够自动推断各层的形状。

它接受input_shape参数,返回更改之后的形状output_shape,定义如下:

def compute_output_shape(self, input_shape):
    ... ...
    return tuple(output_shape)

2.3 基类Layer中的定义

虽然在自定义类的时候,上面的三个方法,build、call、compute_outout_shape 是必须要有的,但是还有很多其他的方法和属性可以帮助我们更好的实现自定义层。下面来该概括性的看一下:

class Layer(object):
    def __init__(self, **kwargs):
     
    @staticmethod
    def _node_key(layer, node_index):

    @property
    def losses(self):
    
    @property
    def updates(self):
      
    @property
    def built(self):
       
    @built.setter
    def built(self, value):
   
    @property
    def trainable_weights(self):

    @trainable_weights.setter
    def trainable_weights(self, weights):
    
    @property
    def non_trainable_weights(self):
      
    @non_trainable_weights.setter
    def non_trainable_weights(self, weights):

    @interfaces.legacy_add_weight_support
    def add_weight(self,name,shape,dtype=None,initializer=None,regularizer=None,
                   trainable=True,
                   constraint=None):
     
    def assert_input_compatibility(self, inputs):
       
    def call(self, inputs, **kwargs):
      
    def __call__(self, inputs, **kwargs):

    def _add_inbound_node(self, input_tensors, output_tensors,
                          input_masks, output_masks,
                          input_shapes, output_shapes, arguments=None):

    def compute_output_shape(self, input_shape):
     
    def compute_mask(self, inputs, mask=None):

    def _get_node_attribute_at_index(self, node_index, attr, attr_name):

    def get_input_shape_at(self, node_index):

    def get_output_shape_at(self, node_index):
       
    def get_input_at(self, node_index):
       
    def get_output_at(self, node_index):

    def get_input_mask_at(self, node_index):
      
    def get_output_mask_at(self, node_index):
       
    @property
    def input(self):

    @property
    def output(self):

    @property
    def input_mask(self):

    @property
    def output_mask(self):
        
    @property
    def input_shape(self):
       
    @property
    def output_shape(self):

    def add_loss(self, losses, inputs=None):

    def add_update(self, updates, inputs=None):

    def get_updates_for(self, inputs):
       
    def get_losses_for(self, inputs):

    @property
    def weights(self):

    def set_weights(self, weights):
       
    def get_weights(self):
       
    def get_config(self):
        
    @classmethod
    def from_config(cls, config):

    def count_params(self):

 

  • 39
    点赞
  • 180
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值