tensorflow测试代码_天坑Tensorflow

251c3dde3cfd33389ef001dc72389f49.png

最近入坑了一个项目,是有关pretraining的,模型已经训练好了,需要做一些downstream任务的测试。然而项目是用tensorflow写的,作为一个重度pytorch用户,我只听说过一些诸如session,eager execution等云里雾里的名词,和Keras整合时的混乱,以及从TF1迁移到TF2时的痛苦。这次硬着头皮上了。

在搭好针对downstream任务的训练和测试代码后,开始finetune。Training accuracy看起来不错,但是validation accuracy怎么这么低,难道overfit这么严重吗?我把validation set的路径指向training data,按说这时的validation accuracy应该和training大致相等,结果还是低到不可思议!那应该是计算accuracy的代码出bug了吧。于是我想打印出predictions来看看,就这样挖出了一个无底的rabbit hole。。。

先说这个项目的一些结构。代码中用了一个叫Estimator的东西,它打包了从训练到测试再到部署一系列环节,而我们只要给它写两个函数:一个model_fn和一个input_fn,其中前者运行模型,后者提供输入数据。对于每一个batch, model_fn返回一个EstimatorSpec,包含了训练和记录所需的信息:

return tf.contrib.tpu.TPUEstimatorSpec(mode=tf.estimator.ModeKeys.EVAL,
                                       loss=losses['loss'],
                                       train_op=train_op,
                                       eval_metrics=(metric_fn, [evaluation[metric]]),
                                       scaffold_fn=scaffold_fn,
                                       host_call=host_call,
                                       predictions=evaluation)

首先试了试用print()打印batch的预测结果:

print(evaluation['predictions'])

然而并没有效果。原来是Estimator会以graph execution而不是eager execution来执行,在建图时只会保留必要的tensor operation,而舍弃像print()之类的效果。


然后我在TF中找出两个打印函数:tf.print()和tf.compat.v1.Print()。这里说一下compat.v1这个模块,它是为了从TF1迁移至TF2时的后向兼容性。所有的TF1函数都被移到了compat.v1中,保留了原有的函数签名和语法,并且还支持在TF2的runtime里运行(只是在优化度方面不及TF2原生的代码)。用compat.v1中的函数写出的代码,同时支持在TF1和TF2中运行,并且能够很方便地被升级成TF2原生代码。但这里的tf.compat.v1.Print()在文档中已被标明deprecated,如有需要应该用tf.print()。

试着用tf.print()打印:

tf.print(evaluation['predictions'])

然而还是没有效果。这个操作应该会被加到graph当中执行的啊?然后在tf.print()的文档里找到了这么一小段话:

800a984e4b20f3803b1d45e88c3c752d.png

好吧,谁让我在用TF1,而且是graph execution呢。Estimator没有session,那我就加一个control_dependencies()试试:

print_op = tf.print(losses['loss'])
with tf.control_dependencies([print_op]):
    losses['loss'] = tf.identity(losses['loss'])

好嘛,终于有动静了,给我糊了一脸stack trace,唯一有用的是这么一段:

ERROR:tensorflow:Error recorded from evaluation_loop: From /job:worker/replica:0/task:0:
Compilation failure: Detected unsupported operations when trying to compile graph _functionalize_body_2[] on XLA_TPU_JIT: StringFormat (No registered 'StringFormat' OpKernel for XLA_TPU_JIT devices compatible with node node StringFormat (defined at modeling.py:417)
        .  Registered:  device='CPU'
)node StringFormat (defined at modeling.py:417)
         [[LoopCond]]
        TPU compilation failed
         [[tpu_compile_succeeded_assert/_6684163811970357570/_1671]]

Errors may have originated from an input operation.
Input Source operations connected to node StringFormat:
 truediv (defined at modeling.py:323)

我把这个“No registered StringFormat OpKernel“错误搜了一下,感觉完全没有相关的讨论。又走到了死胡同。


这时我上朋友圈问了一圈,有好几个好心的朋友提议我试试LoggingTensorHook。大概是长这样:

hook = tf.train.LoggingTensorHook({'accuracy': evaluation['accuracy']}, every_n_iter=10)
return tf.contrib.tpu.TPUEstimatorSpec(evaluation_hooks=[hook],
                                       ...)

结果这次运行到Restoring model parameters时竟然卡住了!没有任何报错信息,只是卡在这里十多分钟完全不走。加一个hook会影响parameter loading真的想不通。

这里再吐槽一下Tensorflow的doc。我是在StackOverflow里看到的tf.train.LoggingTensorHook,然而这个类在TF2的文档里根本找不到,原来这个是TF1中的命名,在TF2中被挪到了tf.estimator.LoggingTensorHook。然而,在compat.v1的文档中,LoggingTensorHook并没有出现在tf.train或tf.estimator模块下,仿佛它在TF1中从没有出现过一般。最后我是特意翻出TF 1.14的文档来(已经被archive到了GitHub里)才找到,并发现了这段信息:

000809668d2f864139470f09f0e96cc0.png

讲道理,改名这种事情应该放在最新版文档的显眼位置,要翻这么久才能找到的信息,不能不让人感到confusing。


实在不行,把Estimator拆成一个自己写的loop,然后用eager execution行不行呢?这样上面的print()以及tf.print()都应该work了吧?

因为TF1默认graph execution,我们先要启用eager execution:

tf.compat.v1.enable_eager_execution()

把TPU设置好:

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_host('10.240.1.10:8470')
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

然后写我们的loop:

with strategy.scope():
    input_fn = tvqa_dataloader.input_fn_builder(config, is_training=False)
    dataset = input_fn(params={'batch_size': 8})
    for data in dataset:
        features, labels = data
        spec = model_fn(features, labels, tf.estimator.ModeKeys.EVAL, params={})
        metric_fn, tensors = spec.eval_metrics
        acc = tensors[0]
        print(tf.make_ndarray(acc))
        tf.get_variable_scope().reuse_variables()

这次运行的时候卡在了model_fn()中创建model的环节,仍然是没有任何报错信息。


既然TF1的文档和支持这么差了,如果把代码升级成TF2会不会解决这个无法print的问题呢。。

<One hour later...>

把代码库升级到TF2了,过程还算顺利,唯一的glitch是slim库作为contrib被踢出了TF2,于是我得重新安装tf_slim包,并更新了一遍各处tfexample.decoder的函数名。

  1. 先上tf.print(),这次应该不用control_dependency了,但是这回又出了些新东西:
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/threading.py", line 917, in _bootstrap_inner
    self.run()
  File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/site-packages/tensorflow/python/tpu/preempted_hook.py", line 87, in run
    recoverable = self._cluster._cloud_tpu_client.recoverable()  # pylint: disable=protected-access
  File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/site-packages/tensorflow/python/tpu/client/client.py", line 264, in recoverable
    elif FLAGS.runtime_oom_exit and self._oom_event():
  File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/site-packages/absl/flags/_flagvalues.py", line 498, in __getattr__
    raise _exceptions.UnparsedFlagAccessError(error_message)
absl.flags._exceptions.UnparsedFlagAccessError: Trying to access flag --runtime_oom_exit before flags were parsed.

2. 再试LoggingTensorHook,现在给了这样一个错误:

ValueError: Passed Tensor("Cast_372:0", shape=(8,), dtype=float32) should have graph attribute that is equal to current graph <tensorflow.python.framework.ops.Graph object at 0x7f5c5cf12ef0>.

又去StackOverflow搜了一圈,感觉这个错误是因为tensor不是从model_fn()里出来的。可我这个accuracy和loss确实都是model_fn()计算出来的啊。TF的GitHub上有一个issue和我遇到的问题相似,但是楼主并没有follow-up自己是如何解决的。

3. 写成eager execution,错误是:

ValueError: Attempt to convert a value (functools.partial(<tensorflow.python.keras.optimizer_v2.learning_rate_schedule.PolynomialDecay object at 0x7fee44640470>, TPUMirroredVariable:{
  0: <tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>,
  1: <tf.Variable 'global_step/replica_1:0' shape=() dtype=int64, numpy=0>,
  2: <tf.Variable 'global_step/replica_2:0' shape=() dtype=int64, numpy=0>,
  3: <tf.Variable 'global_step/replica_3:0' shape=() dtype=int64, numpy=0>,
  4: <tf.Variable 'global_step/replica_4:0' shape=() dtype=int64, numpy=0>,
  5: <tf.Variable 'global_step/replica_5:0' shape=() dtype=int64, numpy=0>,
  6: <tf.Variable 'global_step/replica_6:0' shape=() dtype=int64, numpy=0>,
  7: <tf.Variable 'global_step/replica_7:0' shape=() dtype=int64, numpy=0>
})) with an unsupported type (<class 'functools.partial'>) to a Tensor.

总结一下,这次上手Tensorflow的体验可以说是极差的,官方文档在处理TF1到2的迁移时存在很多混乱,很多函数的详细功能、使用条件、常见问题说得都不是很清楚,很多常见问题在StackOverflow上的讨论也不充分。

结论:还是Pytorch香!

最后祭出这张图:

f9e3ef602d0f995c9cf456d80a1d6cbb.png
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值