最近入坑了一个项目,是有关pretraining的,模型已经训练好了,需要做一些downstream任务的测试。然而项目是用tensorflow写的,作为一个重度pytorch用户,我只听说过一些诸如session,eager execution等云里雾里的名词,和Keras整合时的混乱,以及从TF1迁移到TF2时的痛苦。这次硬着头皮上了。
在搭好针对downstream任务的训练和测试代码后,开始finetune。Training accuracy看起来不错,但是validation accuracy怎么这么低,难道overfit这么严重吗?我把validation set的路径指向training data,按说这时的validation accuracy应该和training大致相等,结果还是低到不可思议!那应该是计算accuracy的代码出bug了吧。于是我想打印出predictions来看看,就这样挖出了一个无底的rabbit hole。。。
先说这个项目的一些结构。代码中用了一个叫Estimator的东西,它打包了从训练到测试再到部署一系列环节,而我们只要给它写两个函数:一个model_fn
和一个input_fn
,其中前者运行模型,后者提供输入数据。对于每一个batch, model_fn
返回一个EstimatorSpec,包含了训练和记录所需的信息:
return tf.contrib.tpu.TPUEstimatorSpec(mode=tf.estimator.ModeKeys.EVAL,
loss=losses['loss'],
train_op=train_op,
eval_metrics=(metric_fn, [evaluation[metric]]),
scaffold_fn=scaffold_fn,
host_call=host_call,
predictions=evaluation)
首先试了试用print()打印batch的预测结果:
print(evaluation['predictions'])
然而并没有效果。原来是Estimator会以graph execution而不是eager execution来执行,在建图时只会保留必要的tensor operation,而舍弃像print()之类的效果。
然后我在TF中找出两个打印函数:tf.print()和tf.compat.v1.Print()。这里说一下compat.v1这个模块,它是为了从TF1迁移至TF2时的后向兼容性。所有的TF1函数都被移到了compat.v1中,保留了原有的函数签名和语法,并且还支持在TF2的runtime里运行(只是在优化度方面不及TF2原生的代码)。用compat.v1中的函数写出的代码,同时支持在TF1和TF2中运行,并且能够很方便地被升级成TF2原生代码。但这里的tf.compat.v1.Print()在文档中已被标明deprecated,如有需要应该用tf.print()。
试着用tf.print()打印:
tf.print(evaluation['predictions'])
然而还是没有效果。这个操作应该会被加到graph当中执行的啊?然后在tf.print()的文档里找到了这么一小段话:
好吧,谁让我在用TF1,而且是graph execution呢。Estimator没有session,那我就加一个control_dependencies()试试:
print_op = tf.print(losses['loss'])
with tf.control_dependencies([print_op]):
losses['loss'] = tf.identity(losses['loss'])
好嘛,终于有动静了,给我糊了一脸stack trace,唯一有用的是这么一段:
ERROR:tensorflow:Error recorded from evaluation_loop: From /job:worker/replica:0/task:0:
Compilation failure: Detected unsupported operations when trying to compile graph _functionalize_body_2[] on XLA_TPU_JIT: StringFormat (No registered 'StringFormat' OpKernel for XLA_TPU_JIT devices compatible with node node StringFormat (defined at modeling.py:417)
. Registered: device='CPU'
)node StringFormat (defined at modeling.py:417)
[[LoopCond]]
TPU compilation failed
[[tpu_compile_succeeded_assert/_6684163811970357570/_1671]]
Errors may have originated from an input operation.
Input Source operations connected to node StringFormat:
truediv (defined at modeling.py:323)
我把这个“No registered StringFormat OpKernel“错误搜了一下,感觉完全没有相关的讨论。又走到了死胡同。
这时我上朋友圈问了一圈,有好几个好心的朋友提议我试试LoggingTensorHook。大概是长这样:
hook = tf.train.LoggingTensorHook({'accuracy': evaluation['accuracy']}, every_n_iter=10)
return tf.contrib.tpu.TPUEstimatorSpec(evaluation_hooks=[hook],
...)
结果这次运行到Restoring model parameters时竟然卡住了!没有任何报错信息,只是卡在这里十多分钟完全不走。加一个hook会影响parameter loading真的想不通。
这里再吐槽一下Tensorflow的doc。我是在StackOverflow里看到的tf.train.LoggingTensorHook,然而这个类在TF2的文档里根本找不到,原来这个是TF1中的命名,在TF2中被挪到了tf.estimator.LoggingTensorHook。然而,在compat.v1的文档中,LoggingTensorHook并没有出现在tf.train或tf.estimator模块下,仿佛它在TF1中从没有出现过一般。最后我是特意翻出TF 1.14的文档来(已经被archive到了GitHub里)才找到,并发现了这段信息:
讲道理,改名这种事情应该放在最新版文档的显眼位置,要翻这么久才能找到的信息,不能不让人感到confusing。
实在不行,把Estimator拆成一个自己写的loop,然后用eager execution行不行呢?这样上面的print()以及tf.print()都应该work了吧?
因为TF1默认graph execution,我们先要启用eager execution:
tf.compat.v1.enable_eager_execution()
把TPU设置好:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_host('10.240.1.10:8470')
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
然后写我们的loop:
with strategy.scope():
input_fn = tvqa_dataloader.input_fn_builder(config, is_training=False)
dataset = input_fn(params={'batch_size': 8})
for data in dataset:
features, labels = data
spec = model_fn(features, labels, tf.estimator.ModeKeys.EVAL, params={})
metric_fn, tensors = spec.eval_metrics
acc = tensors[0]
print(tf.make_ndarray(acc))
tf.get_variable_scope().reuse_variables()
这次运行的时候卡在了model_fn()中创建model的环节,仍然是没有任何报错信息。
既然TF1的文档和支持这么差了,如果把代码升级成TF2会不会解决这个无法print的问题呢。。
<One hour later...>
把代码库升级到TF2了,过程还算顺利,唯一的glitch是slim库作为contrib被踢出了TF2,于是我得重新安装tf_slim包,并更新了一遍各处tfexample.decoder的函数名。
- 先上tf.print(),这次应该不用control_dependency了,但是这回又出了些新东西:
Exception in thread Thread-1:
Traceback (most recent call last):
File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/threading.py", line 917, in _bootstrap_inner
self.run()
File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/site-packages/tensorflow/python/tpu/preempted_hook.py", line 87, in run
recoverable = self._cluster._cloud_tpu_client.recoverable() # pylint: disable=protected-access
File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/site-packages/tensorflow/python/tpu/client/client.py", line 264, in recoverable
elif FLAGS.runtime_oom_exit and self._oom_event():
File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/site-packages/absl/flags/_flagvalues.py", line 498, in __getattr__
raise _exceptions.UnparsedFlagAccessError(error_message)
absl.flags._exceptions.UnparsedFlagAccessError: Trying to access flag --runtime_oom_exit before flags were parsed.
2. 再试LoggingTensorHook,现在给了这样一个错误:
ValueError: Passed Tensor("Cast_372:0", shape=(8,), dtype=float32) should have graph attribute that is equal to current graph <tensorflow.python.framework.ops.Graph object at 0x7f5c5cf12ef0>.
又去StackOverflow搜了一圈,感觉这个错误是因为tensor不是从model_fn()里出来的。可我这个accuracy和loss确实都是model_fn()计算出来的啊。TF的GitHub上有一个issue和我遇到的问题相似,但是楼主并没有follow-up自己是如何解决的。
3. 写成eager execution,错误是:
ValueError: Attempt to convert a value (functools.partial(<tensorflow.python.keras.optimizer_v2.learning_rate_schedule.PolynomialDecay object at 0x7fee44640470>, TPUMirroredVariable:{
0: <tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>,
1: <tf.Variable 'global_step/replica_1:0' shape=() dtype=int64, numpy=0>,
2: <tf.Variable 'global_step/replica_2:0' shape=() dtype=int64, numpy=0>,
3: <tf.Variable 'global_step/replica_3:0' shape=() dtype=int64, numpy=0>,
4: <tf.Variable 'global_step/replica_4:0' shape=() dtype=int64, numpy=0>,
5: <tf.Variable 'global_step/replica_5:0' shape=() dtype=int64, numpy=0>,
6: <tf.Variable 'global_step/replica_6:0' shape=() dtype=int64, numpy=0>,
7: <tf.Variable 'global_step/replica_7:0' shape=() dtype=int64, numpy=0>
})) with an unsupported type (<class 'functools.partial'>) to a Tensor.
总结一下,这次上手Tensorflow的体验可以说是极差的,官方文档在处理TF1到2的迁移时存在很多混乱,很多函数的详细功能、使用条件、常见问题说得都不是很清楚,很多常见问题在StackOverflow上的讨论也不充分。
结论:还是Pytorch香!
最后祭出这张图: