Tensorflow2自定义Layers之__init__,build和call详解
参考官方链接:https://tensorflow.google.cn/tutorials/customization/custom_layers
闲言碎语:
如果想要自定义自己的Layer,那么使用tf.keras.Layer 来创建自己的类是必不可少的。
但是笔者最近在搞自己的毕设时,遇到了自定义层和tf.function的Error:
ValueError: tf.function-decorated function tried to create variables on non-first call.
初步判断为:add_weight和tf.function冲突。在尝试解决此冲突时,发现对def build()函数的理解还不够深刻。所以作此记录。
–init–,build和call
- —init—函数:这个函数用于对所有独立的输入进行初始化。(独立的输入:特指和训练数据无关的输入)(这个函数仅被执行一次)
- build函数:这个函数用于当你知道输入Tensor的shape后,完成其余的初始化。(即需要知道数据的shape,但不需要知道数据的具体值)(注意:这个函数仅在Call被第一次调用时执行)
- call函数:这个函数就是用来前向计算的函数了。
注意:build函数并非是必须的。如果你事先知道Tensor的shape,即数据是显式的,那么完全可以在—init—中执行build内初始化的内容,从而不需要定义build函数。当然build也有它的好处,就是当你事先不知道数据的shape时,build的作用就凸显出来了,例如,假如你不知道输入图片的size是224x224还是512x512,那么如何初始化自己的Layer呢?build提供了此类的解决方案。还有一点,笔者认为它可能解决上述笔者遇到的问题,不过目前还没有验证。在文末应该会验证。
为了更好的了解上述函数的特点,下面举一个简单的例子:
import tensorflow as tf
class MyDenseLayer(tf.keras.layers.Layer):
def __init__(self, num_outputs):
super(MyDenseLayer, self).__init__()
print('init 被执行')
self.num_outputs = num_outputs
self.i = 0
print('Init:This is i',self.i)
self.i = self.i +1
def build(self,input_shape):
print('build 被执行')
print('input_shape',input_shape)
print('Build:This is i',self.i)
self.kernel = self.add_weight("kernel",
shape=[int(input_shape[-1]),
self.num_outputs])
def call(self, input):
print('call 被执行')
return tf.matmul(input, self.kernel)
layer = MyDenseLayer(10)
_ = layer(tf.zeros([10, 5])) # Calling the layer `.builds` it.
print([var.name for var in layer.trainable_variables])
_ = layer(tf.ones([10, 5]))
print([var.name for var in layer.trainable_variables])
输出结果如下所示:
init 被执行
Init:This is i 0
build 被执行
input_shape (10, 5)
Build:This is i 1
call 被执行
['my_dense_layer/kernel:0']
call 被执行
['my_dense_layer/kernel:0']
结合上述代码看一看,是不是很容易懂了的。
init先被执行,且仅执行一次。build也是,第一次调用Call函数执行。call每次被调用都会被执行。
总结
通过上述解释,发现了笔者的问题并不出现在Build上。但是经过了近一天的查找,最终也成功定位到了问题的所在点。
搞了这么多,走偏了很多路(包括这篇blog的诞生)。不过最终定位到了问题,Nice!
快一点了,睡觉!
(后续补充:该问题已经解决。
- 出现的问题:我的模型是CNN和GCN的嵌套,在CNN的call函数内调用到了GCN,则相当于在CNN的call被调用的时候执行了GCN model,而GCN内的call又包含tf.variable函数,则相当于循环调用了tf.variable函数,因此出现了问题。
- 解决方案:将GCN的tf.variable放到–init–函数内,且同时调用GCN和CNN的–init–函数,避免了tf.variable被循环问题)