问题记录
- Ubuntu 20.04 LTS
- NVIDIA 驱动 465.27、CUDA 11.3(nvidia-smi 显示)
- tensorflow-gpu 1.15.5
tensorflow.python.framework.errors_impl.InvalidArgumentError: assertion failed: [[0.919921875][0.68359375][0.517578125]...] [[0.814453125][0.587890625][0.42578125]...]
[[{{node Assert_1/AssertGuard/Assert}}]]
[[IteratorGetNext]]
完整报错如下:
Traceback (most recent call last):
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/client/session.py", line 1365, in _do_call
return fn(*args)
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/client/session.py", line 1349, in _run_fn
return self._call_tf_sessionrun(options, feed_dict, fetch_list,
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/client/session.py", line 1441, in _call_tf_sessionrun
return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node Dataset_map_transform_and_pad_input_data_fn_433}} assertion failed: [[0.919921875][0.68359375][0.517578125]...] [[0.814453125][0.587890625][0.42578125]...]
[[{{node Assert_1/AssertGuard/Assert}}]]
[[IteratorGetNext]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "model_main.py", line 108, in <module>
tf.app.run()
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/platform/app.py", line 40, in run
_run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
File "/home/haotian/.local/lib/python3.8/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/home/haotian/.local/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "model_main.py", line 104, in main
tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/training.py", line 473, in train_and_evaluate
return executor.run()
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/training.py", line 613, in run
return self.run_local()
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/training.py", line 710, in run_local
self._estimator.train(
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 370, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1161, in _train_model
return self._train_model_default(input_fn, hooks, saving_listeners)
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1193, in _train_model_default
return self._train_with_estimator_spec(estimator_spec, worker_hooks,
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1494, in _train_with_estimator_spec
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/training/monitored_session.py", line 750, in run
return self._sess.run(
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/training/monitored_session.py", line 1255, in run
return self._sess.run(
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/training/monitored_session.py", line 1360, in run
raise six.reraise(*original_exc_info)
File "/home/haotian/.local/lib/python3.8/site-packages/six.py", line 703, in reraise
raise value
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/training/monitored_session.py", line 1345, in run
return self._sess.run(*args, **kwargs)
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/training/monitored_session.py", line 1413, in run
outputs = _WrappedSession.run(
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/training/monitored_session.py", line 1176, in run
return self._sess.run(*args, **kwargs)
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/client/session.py", line 955, in run
result = self._run(None, fetches, feed_dict, options_ptr,
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/client/session.py", line 1179, in _run
results = self._do_run(handle, final_targets, final_fetches,
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/client/session.py", line 1358, in _do_run
return self._do_call(_run_fn, feeds, fetches, targets, options,
File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/client/session.py", line 1384, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: assertion failed: [[0.919921875][0.68359375][0.517578125]...] [[0.814453125][0.587890625][0.42578125]...]
[[{{node Assert_1/AssertGuard/Assert}}]]
[[IteratorGetNext]]
解决方法
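检查数据集的标注文件,尤其是自己做过数据增强的数据集:很可能是 xml 标注损坏了(标注框坐标不合法)。从上面报错中的 function_node 可以看到,断言是在输入管线的 transform_and_pad_input_data_fn 里失败的,请依次排查以下几项: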
- 检查 xmin、xmax、ymin、ymax 有没有负值;
- 检查 xmin、xmax、ymin、ymax 是否超出图像范围。可以在生成 tfrecords 的 csv_to_tfrecords.py 里找到下面这个循环,把除以宽/高之后的归一化坐标截断到 [0, 1] 区间,修改成这样:
for index, row in group.object.iterrows():
    # Normalise each corner by the image size, then clamp it into [0, 1] so
    # that a box spilling past the image border cannot write an out-of-range
    # coordinate into the TFRecord. This replaces the plain
    # `row['xmin'] / width`-style appends of the stock script.
    xmins.append(min(max(row['xmin'] / width, 0.0), 1.0))
    xmaxs.append(min(max(row['xmax'] / width, 0.0), 1.0))
    ymins.append(min(max(row['ymin'] / height, 0.0), 1.0))
    ymaxs.append(min(max(row['ymax'] / height, 0.0), 1.0))
    classes_text.append(row['class'].encode('utf8'))
    classes.append(class_text_to_int(row['class']))
- 检查是否满足 xmin < xmax、ymin < ymax。下面这段程序会在两者顺序颠倒时把它们交换回来(坐标同时被截断为整数),并把修正后的 xml 保存到另一个目录:
import os
import shutil
import xml.etree.ElementTree as ET

xml_src_dir = "..."  # directory holding the original XML annotations
xml_dst_dir = "..."  # directory that receives the corrected copies


def fix_bndbox_order(src_dir, dst_dir):
    """Copy every ``.xml`` annotation from src_dir to dst_dir with ordered boxes.

    For each ``<object>`` element the ``bndbox`` corners are rewritten so
    that ``xmin <= xmax`` and ``ymin <= ymax`` (reversed pairs are swapped),
    and every coordinate is truncated to an integer. Files that do not end
    in ``.xml`` are ignored.

    Args:
        src_dir: Directory containing the original annotation files.
        dst_dir: Directory for the corrected files; created if missing.

    Returns:
        The number of XML files written to dst_dir.
    """
    os.makedirs(dst_dir, exist_ok=True)
    written = 0
    for file_name in os.listdir(src_dir):
        if not file_name.endswith('.xml'):
            continue
        # os.path.join keeps this working whether or not the directory
        # strings end with a path separator.
        src_path = os.path.join(src_dir, file_name)
        dst_path = os.path.join(dst_dir, file_name)
        shutil.copy(src_path, dst_path)
        tree = ET.parse(dst_path)
        for obj in tree.iter('object'):
            box = obj.find('bndbox')
            raw = {
                tag: float(box.find(tag).text)
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            }
            fixed = {
                'xmin': min(raw['xmin'], raw['xmax']),
                'ymin': min(raw['ymin'], raw['ymax']),
                'xmax': max(raw['xmin'], raw['xmax']),
                'ymax': max(raw['ymin'], raw['ymax']),
            }
            for tag, value in fixed.items():
                # int() truncates any fractional part of the coordinate.
                box.find(tag).text = str(int(value))
        tree.write(dst_path)
        written += 1
    return written


if __name__ == "__main__":
    fix_bndbox_order(xml_src_dir, xml_dst_dir)