TensorFlow model training error: Invalid argument: Name: <unknown>, Feature: weight (data type: float) is required but could not be found.
Error output:
2020-09-09 11:03:49.962924: W tensorflow/core/framework/op_kernel.cc:1502] OP_REQUIRES failed at example_parsing_ops.cc:144 : Invalid argument: Name: <unknown>, Feature: weight (data type: float) is required but could not be found.
2020-09-09 11:03:49.963781: W tensorflow/core/framework/op_kernel.cc:1502] OP_REQUIRES failed at example_parsing_ops.cc:144 : Invalid argument: Name: <unknown>, Feature: weight (data type: float) is required but could not be found.
Traceback (most recent call last):
File "train_model.py", line 235, in <module>
tf.app.run(main)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 40, in run
_run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/absl/app.py", line 299, in run
_run_main(main, args)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/absl/app.py", line 250, in _run_main
sys.exit(main(argv))
File "train_model.py", line 212, in main
classifier.train(input_fn=lambda: data_helper.train_input_fn(file_list, FLAGS.batch_size), hooks=hooks)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 367, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1158, in _train_model
return self._train_model_default(input_fn, hooks, saving_listeners)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1192, in _train_model_default
saving_listeners)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1484, in _train_with_estimator_spec
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 754, in run
run_metadata=run_metadata)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1252, in run
run_metadata=run_metadata)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1353, in run
raise six.reraise(*original_exc_info)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1338, in run
return self._sess.run(*args, **kwargs)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1411, in run
run_metadata=run_metadata)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1169, in run
return self._sess.run(*args, **kwargs)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 950, in run
run_metadata_ptr)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1173, in _run
feed_dict_tensor, options, run_metadata)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
run_metadata)
File "/Users/wang/anaconda3/envs/tf/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Name: <unknown>, Feature: weight (data type: float) is required but could not be found.
[[{{node ParseExample/ParseExample}}]]
[[IteratorGetNext]]
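Analysis:
The error says that the feature weight (data type: float) is required but could not be found, meaning the "weight" feature the code tries to parse is missing from the data. My raw training data is in JSON format and originally contained a weight key, and the training code reads and parses the value of this field. After carefully checking the pipeline that converts the data to TFRecord format, I found that one filter condition was too strict, so the weight field was never written to the final examples. That is why parsing "weight" fails during training.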
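The line in the training code that parses this field:
record_format['weight'] = tf.io.FixedLenFeature([1], tf.float32)
Because no default_value is passed, TensorFlow treats weight as a required feature, so parsing fails on any example that lacks it. This is exactly the line that raises the error.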
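To make the "required" behaviour concrete, here is a toy pure-Python model of the check that tf.io.parse_example applies to FixedLenFeature entries. It is only an illustration, not TensorFlow's implementation; the function name parse_toy_example and the dict-based example format are hypothetical. As with FixedLenFeature, a default of None means "required", while any other default is used as the fallback value.

```python
# Toy, pure-Python model of the "required feature" check that
# tf.io.parse_example applies to tf.io.FixedLenFeature entries.
# Illustration only: this is NOT TensorFlow's implementation.


def parse_toy_example(spec, example):
    """Parse one dict-based example against a spec.

    spec maps feature name -> default value; None means the feature was
    declared without a default (like FixedLenFeature([1], tf.float32)),
    so it is required.
    example maps feature name -> value for one record (hypothetical format).
    """
    parsed = {}
    for name, default in spec.items():
        if name in example:
            parsed[name] = example[name]
        elif default is None:
            # No fallback value: a missing feature is an error.
            raise ValueError(
                "Feature: %s is required but could not be found." % name)
        else:
            # A default was declared, so the missing feature falls back to it.
            parsed[name] = default
    return parsed


if __name__ == "__main__":
    spec = {"label": None, "weight": None}
    print(parse_toy_example(spec, {"label": [1], "weight": [0.5]}))
    try:
        parse_toy_example(spec, {"label": [0]})  # weight was never written
    except ValueError as err:
        print(err)
```

The second call mirrors the situation in this post: the record lacks weight and the spec gives no default, so parsing stops with an error instead of silently producing a value.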
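Solutions:
1. Go back and regenerate the data so that every example contains this field.
2. If this field has no effect on training your model, delete that line from the training code, along with any code that uses the parsed weight value. This avoids re-running the data pipeline.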
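If you go with option 1, it may help to validate the raw JSON records before converting them to TFRecord, so that a filter that drops weight is caught at data-generation time rather than in the middle of training. A minimal sketch, assuming one JSON object per line; REQUIRED_FEATURES and find_missing_features are hypothetical names, and the required list should mirror the features your record_format declares without a default_value.

```python
import json

# Hypothetical list of features that the training code parses as required
# (FixedLenFeature without a default_value); adjust to your own record_format.
REQUIRED_FEATURES = ("weight",)


def find_missing_features(json_lines, required=REQUIRED_FEATURES):
    """Return (line_number, missing_keys) for every JSON record lacking a required key."""
    problems = []
    for line_no, line in enumerate(json_lines, start=1):
        record = json.loads(line)
        missing = [key for key in required if key not in record]
        if missing:
            problems.append((line_no, missing))
    return problems


if __name__ == "__main__":
    sample = [
        '{"label": 1, "weight": 0.5}',
        '{"label": 0}',  # weight dropped, e.g. by an overly strict filter
    ]
    for line_no, missing in find_missing_features(sample):
        print("line %d is missing %s" % (line_no, ", ".join(missing)))
```

Running a check like this on the filtered output, right before writing the TFRecord files, turns a mid-training InvalidArgumentError into an early, line-level report.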
Reference:
https://github.com/tensorflow/models/issues/5582