Tensorflow2自定义Layers之__init__,build和call详解

Tensorflow2自定义Layers之__init__,build和call详解

参考官方链接:https://tensorflow.google.cn/tutorials/customization/custom_layers

闲言碎语:

如果想要自定义自己的Layer,那么使用tf.keras.Layer 来创建自己的类是必不可少的。

但是笔者最近在搞自己的毕设时,遇到了自定义层和tf.function的Error:
ValueError: tf.function-decorated function tried to create variables on non-first call.
初步判断为:add_weight和tf.function冲突。在尝试解决此冲突时,发现对def build()函数的理解还不够深刻。所以作此记录。


–init–,build和call

  • —init—函数:这个函数用于对所有独立的输入进行初始化。(独立的输入:特指和训练数据无关的输入)(这个函数仅被执行一次)
  • build函数:这个函数用于当你知道输入Tensor的shape后,完成其余的初始化。(即需要知道数据的shape,但不需要知道数据的具体值)(注意:这个函数仅在Call被第一次调用时执行)
  • call函数:这个函数就是用来前向计算的函数了。

注意:build函数并非是必须的。如果你事先知道Tensor的shape,即数据是显式的,那么完全可以在—init—中执行build内初始化的内容,从而不需要定义build函数。当然build也有它的好处,就是当你事先不知道数据的shape时,build的作用就凸显出来了,例如,假如你不知道输入图片的size是224x224还是512x512,那么如何初始化自己的Layer呢?build提供了此类的解决方案。还有一点,笔者认为它可能解决上述笔者遇到的问题,不过目前还没有验证。在文末应该会验证。

为了更好的了解上述函数的特点,下面举一个简单的例子:

import tensorflow as tf
class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super(MyDenseLayer, self).__init__()
    print('init 被执行')
    self.num_outputs = num_outputs
    self.i = 0
    print('Init:This is i',self.i)
    self.i = self.i +1
  def build(self,input_shape):
    print('build 被执行')
    print('input_shape',input_shape)
    print('Build:This is i',self.i)
    self.kernel = self.add_weight("kernel",
                                  shape=[int(input_shape[-1]),
                                         self.num_outputs])
  
  def call(self, input):
    print('call 被执行')
    return tf.matmul(input, self.kernel)

layer = MyDenseLayer(10)
_ = layer(tf.zeros([10, 5])) # Calling the layer `.builds` it.
print([var.name for var in layer.trainable_variables])
_ = layer(tf.ones([10, 5]))
print([var.name for var in layer.trainable_variables])

输出结果如下所示:

init 被执行
Init:This is i 0
build 被执行
input_shape (10, 5)
Build:This is i 1
call 被执行
['my_dense_layer/kernel:0']
call 被执行
['my_dense_layer/kernel:0']

结合上述代码看一看,是不是很容易懂了的。
init先被执行,且仅执行一次。build也是,第一次调用Call函数执行。call每次被调用都会被执行。

总结

通过上述解释,发现了笔者的问题并不出现在Build上。但是经过了近一天的查找,最终也成功定位到了问题的所在点。

搞了这么多,走偏了很多路(包括这篇blog的诞生)。不过最终定位到了问题,Nice!

快一点了,睡觉!

(后续补充:该问题已经解决。

  • 出现的问题:我的模型是CNN和GCN的嵌套,在CNN的call函数内调用到了GCN,则相当于在CNN的call被调用的时候执行了GCN model,而GCN内的call又包含tf.variable函数,则相当于循环调用了tf.variable函数,因此出现了问题。
  • 解决方案:将GCN的tf.variable放到–init–函数内,且同时调用GCN和CNN的–init–函数,避免了tf.variable被循环问题)
  • 13
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值