Tensorflow2 Bug: triggered tf.function retracing (已解决)

Tensorflow2 Bug: triggered tf.function retracing

背景

Bug在通过tf.function模式下运行时出现,应该与AutoGraph模式本身的使用机制有关, tensorflow中关于该Bug的帖子:

本人tensorflow2也用了快一年了,之前也写过好多相关代码/项目,都没有出现过该问题,这次突然出现这个问题, 好郁闷!

尝试

1.根据提示, tf.function具有experimental_relax_shapes = True选项,该选项可放宽参数形状,从而避免不必要的跟踪, 加上该参数后,即@tf.function(experimental_relax_shapes = True)发现并不管用,这就说明,并不是这个原因。实际我的程序中每个batch的尺寸都是固定的。
2.参考博客:https://blog.csdn.net/xygl2009/article/details/104443654, 通过tf.functioninput_signature参数来给方法的每个参数定义signature,如下:试了下,还是不管用。

train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]

@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
	xxx  

3.最终我发现了问题所在:
@tf.function()
def train_step(batch_id: int) :
xxxx
函数train_step中由一个int类型的参数,然后利用该int参数来生成Tensor类型变量,并参与计算。这会导致一个问题:int参数是Python参数,在这些情况下,Python值的变化可能会触发不必要的回溯。举例来说,遍历batch循环训练,AutoGraph将动态展开。换句话说,每次call train_step时都会触发回溯,从而大大延迟程序执行,出现这个警告
解决方式: 通常,Python的参数被用来控制超参数和图的结构-例如,num_layers=10或training=True或nonlinearity=‘relu’。只是这样的话,并没有什么问题,我之前就这样用过, 而在这里,由于Python的参数参与了图的计算, 再用tf.function一静态图的方式运行时就会出现这个问题。然后我将int类型的参数在传入是转换为Tensor格式, 问题自动消失了。

References

1.https://www.tensorflow.org/api_docs/python/tf/function
2.https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args
3.https://stackoverflow.com/questions/61647404/tensorflow-2-getting-warningtensorflow9-out-of-the-last-9-calls-to-function
4.

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MasterQKK 被注册

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值