Tensorflow2 Bug: triggered tf.function retracing
背景
该Bug在通过tf.function模式下运行时出现,应该与AutoGraph模式本身的使用机制有关, tensorflow中关于该Bug的帖子:
本人tensorflow2也用了快一年了,之前也写过好多相关代码/项目,都没有出现过该问题,这次突然出现这个问题, 好郁闷!
尝试
1.根据提示, tf.function
具有experimental_relax_shapes = True
选项,该选项可放宽参数形状,从而避免不必要的跟踪, 加上该参数后,即@tf.function(experimental_relax_shapes = True)
发现并不管用,这就说明,并不是这个原因。实际我的程序中每个batch的尺寸都是固定的。
2.参考博客:https://blog.csdn.net/xygl2009/article/details/104443654, 通过tf.function的input_signature参数来给方法的每个参数定义signature,如下:试了下,还是不管用。
train_step_signature = [
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]
@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
xxx
3.最终我发现了问题所在:
@tf.function()
def train_step(batch_id: int) :
xxxx
函数train_step中由一个int类型的参数,然后利用该int参数来生成Tensor类型变量,并参与计算。这会导致一个问题:int参数是Python参数,在这些情况下,Python值的变化可能会触发不必要的回溯。举例来说,遍历batch循环训练,AutoGraph将动态展开。换句话说,每次call train_step时都会触发回溯,从而大大延迟程序执行,出现这个警告
。
解决方式: 通常,Python的参数被用来控制超参数和图的结构-例如,num_layers=10或training=True或nonlinearity=‘relu’。只是这样的话,并没有什么问题,我之前就这样用过, 而在这里,由于Python的参数参与了图的计算, 再用tf.function一静态图的方式运行时就会出现这个问题。然后我将int类型的参数在传入是转换为Tensor格式
, 问题自动消失了。
References
1.https://www.tensorflow.org/api_docs/python/tf/function
2.https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args
3.https://stackoverflow.com/questions/61647404/tensorflow-2-getting-warningtensorflow9-out-of-the-last-9-calls-to-function
4.