tf2.0报tf.function错误

错误信息全部信息:

WARNING:tensorflow:11 out of the last 11 calls to <function
train..train_step at 0x7f6d1843e840> triggered tf.function
retracing. Tracing is expensive and the excessive number of tracings
is likely due to passing python objects instead of tensors. Also,
tf.function has experimental_relax_shapes=True option that relaxes
argument shapes that can avoid unnecessary retracing. Please refer to
https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args
and https://www.tensorflow.org/api_docs/python/tf/function for more
details.

StackOverflow上的解答多有出入,一般是他们写的代码本身问题。和上面错误信息描述的不一致。经过查看tf2.0官方的示例代码:https://tensorflow.google.cn/tutorials/text/transformer
找到如下代码:

train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]

@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
  tar_inp = tar[:, :-1]
  tar_real = tar[:, 1:]
  
  enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
  
  with tf.GradientTape() as tape:
    predictions, _ = transformer(inp, tar_inp, 
                                 True, 
                                 enc_padding_mask, 
                                 combined_mask, 
                                 dec_padding_mask)
    loss = loss_function(tar_real, predictions)

注意那个train_step_signature和tf.function(input_signature=train_step_signature)。
个人推测train函数有入参时候,这里必须声明入参的类型。比如这个例子中有两个入参则在train_step_signature中声明

评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值