显然,这三个函数都是从tf.keras.layers.Layer处继承而来的。
_init_ 可以在其中进行所有与输入无关的初始化
build 知道输入张量的形状,并可以进行其余的初始化
call 可以在其中进行前向计算
官方API的例子:
class MyDenseLayer(tf.keras.layers.Layer):
def __init__(self, num_outputs):
super(MyDenseLayer, self).__init__()
self.num_outputs = num_outputs
def build(self, input_shape):
self.kernel = self.add_weight("kernel",
shape=[int(input_shape[-1]),
self.num_outputs])
def call(self, inputs):
return tf.matmul(inputs, self.kernel)
layer = MyDenseLayer(10)
_ = layer(tf.zeros([10, 5]))
print([var.name for var in layer.trainable_variables])
输出: ['my_dense_layer/kernel:0']
从直观上理解,似乎__init__()和build()函数都在对Layer进行初始化,都初始化了一些成员函数,而call()函数则是在该layer被调用时执行。
简单翻译,就是说官方推荐凡是tf.keras.layers.Layer的派生类都要实现__init__(),build(), call()这三个方法
init():保存成员变量的设置
build():在call()函数第一次执行时会被调用一次,这时候可以知道输入数据的shape。
返回去看一看,果然是__init__()函数中只初始化了输出数据的shape,而输入数据的shape需要在build()函数中动态获取,这也解释了为什么在有__init__()函数时还需要使用build()函数
call(): call()函数把对象当做函数来使用,即当其被调用时会被执行。当call被第一次调用的时候,会先执行build()方法初始化变量,但后面再调用到call的时候,是不会再去执行build()方法初始化变量。
从上面的官方例子可以简单梳理脉络,但是对我来,发现程序正如我们前面提到把对象当做函数来使用
_ = layer(tf.zeros([10, 5]))
调用了call()
并没有调用 build方法 去源码 tensorflow.python.keras.layers.Layer.call方法中查看:
input_shapes = None
if all(hasattr(x, 'shape') for x in input_list):
input_shapes = nest.map_structure(lambda x: x.shape, inputs)
if not hasattr(self.build, '_is_default'):
with tf_utils.maybe_init_scope(self):
self.build(input_shape)
发现call方法中调用了build 并且提供参数 input_shape。
- 为什么有些要重写build呢?
官方是这么写的:
the advantage of creating them in build is that it enables late variable creation based on the shape of the inputs the layer will operate on.
因为可以单独调用build生成输入权重信息,支持基于层将操作的输入的形状的后期变量创建。
即初始化时的操作可能需要自定义。
例:
来源自 DeepCtr的deepctr/layers/utils.py
def build(self, input_shape):
if self.use_bias:
self.bias = self.add_weight(name='linear_bias',
shape=(1,),
initializer=tf.keras.initializers.Zeros(),
trainable=True)
if self.mode == 1:
self.kernel = self.add_weight(
'linear_kernel',
shape=[int(input_shape[-1]), 1],
initializer=tf.keras.initializers.glorot_normal(self.seed),
regularizer=tf.keras.regularizers.l2(self.l2_reg),
trainable=True)
elif self.mode == 2:
self.kernel = self.add_weight(
'linear_kernel',
shape=[int(input_shape[1][-1]), 1],
initializer=tf.keras.initializers.glorot_normal(self.seed),
regularizer=tf.keras.regularizers.l2(self.l2_reg),
trainable=True)
super(Linear, self).build(input_shape) # Be sure to call this somewhere!
参考 时光碎了天 的博客 | beking00700 的博客