Training ssd_mobilenet_v1_coco with TF 1.15.5 fails with tensorflow.python.framework.errors_impl.InvalidArgumentError

Problem

  • Ubuntu 20.04 LTS
  • NVIDIA driver 465.27 (as reported by nvidia-smi), CUDA 11.3
  • tensorflow-gpu 1.15.5
     
tensorflow.python.framework.errors_impl.InvalidArgumentError:  assertion failed: [[0.919921875][0.68359375][0.517578125]...] [[0.814453125][0.587890625][0.42578125]...]
	 [[{{node Assert_1/AssertGuard/Assert}}]]
	 [[IteratorGetNext]]

Full traceback:

Traceback (most recent call last):
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/client/session.py", line 1365, in _do_call
    return fn(*args)
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/client/session.py", line 1349, in _run_fn
    return self._call_tf_sessionrun(options, feed_dict, fetch_list,
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/client/session.py", line 1441, in _call_tf_sessionrun
    return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node Dataset_map_transform_and_pad_input_data_fn_433}} assertion failed: [[0.919921875][0.68359375][0.517578125]...] [[0.814453125][0.587890625][0.42578125]...]
	 [[{{node Assert_1/AssertGuard/Assert}}]]
	 [[IteratorGetNext]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "model_main.py", line 108, in <module>
    tf.app.run()
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/home/haotian/.local/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/haotian/.local/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "model_main.py", line 104, in main
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/training.py", line 473, in train_and_evaluate
    return executor.run()
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/training.py", line 613, in run
    return self.run_local()
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/training.py", line 710, in run_local
    self._estimator.train(
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 370, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1161, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1193, in _train_model_default
    return self._train_with_estimator_spec(estimator_spec, worker_hooks,
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1494, in _train_with_estimator_spec
    _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/training/monitored_session.py", line 750, in run
    return self._sess.run(
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/training/monitored_session.py", line 1255, in run
    return self._sess.run(
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/training/monitored_session.py", line 1360, in run
    raise six.reraise(*original_exc_info)
  File "/home/haotian/.local/lib/python3.8/site-packages/six.py", line 703, in reraise
    raise value
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/training/monitored_session.py", line 1345, in run
    return self._sess.run(*args, **kwargs)
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/training/monitored_session.py", line 1413, in run
    outputs = _WrappedSession.run(
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/training/monitored_session.py", line 1176, in run
    return self._sess.run(*args, **kwargs)
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/client/session.py", line 955, in run
    result = self._run(None, fetches, feed_dict, options_ptr,
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/client/session.py", line 1179, in _run
    results = self._do_run(handle, final_targets, final_fetches,
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/client/session.py", line 1358, in _do_run
    return self._do_call(_run_fn, feeds, fetches, targets, options,
  File "/home/haotian/.local/lib/python3.8/site-packages/tensorflow_core/python/client/session.py", line 1384, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError:  assertion failed: [[0.919921875][0.68359375][0.517578125]...] [[0.814453125][0.587890625][0.42578125]...]
	 [[{{node Assert_1/AssertGuard/Assert}}]]
	 [[IteratorGetNext]]

 

Solution

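The assertion fails inside the input pipeline (Dataset_map_transform_and_pad_input_data_fn, reached through IteratorGetNext), that is, while the training data is being preprocessed, so the problem lies in the data rather than in the model. Check the dataset labels, especially if you applied your own data augmentation: the XML annotation files are very likely corrupted.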

  • Check whether any of xmin, xmax, ymin, ymax is negative;
     
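  • Check whether xmin, xmax, ymin, ymax fall outside the image. In the script that generates the TFRecords (csv_to_tfrecords.py), find the following loop and change it as shown, so that each coordinate is normalized by the image width or height and then clamped to [0, 1]: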
for index, row in group.object.iterrows():
    # Original lines (no range check):
    # xmins.append(row['xmin'] / width)
    # xmaxs.append(row['xmax'] / width)
    # ymins.append(row['ymin'] / height)
    # ymaxs.append(row['ymax'] / height)

    # group, width, height and the xmins/xmaxs/ymins/ymaxs/classes_text/classes
    # lists are defined earlier in the surrounding function of csv_to_tfrecords.py.
    # Normalize each coordinate by the image size, then clamp it to [0.0, 1.0].
    xmn = row['xmin'] / width
    if xmn < 0.0:
        xmn = 0.0
    elif xmn > 1.0:
        xmn = 1.0
    xmins.append(xmn)

    xmx = row['xmax'] / width
    if xmx < 0.0:
        xmx = 0.0
    elif xmx > 1.0:
        xmx = 1.0
    xmaxs.append(xmx)

    ymn = row['ymin'] / height
    if ymn < 0.0:
        ymn = 0.0
    elif ymn > 1.0:
        ymn = 1.0
    ymins.append(ymn)

    ymx = row['ymax'] / height
    if ymx < 0.0:
        ymx = 0.0
    elif ymx > 1.0:
        ymx = 1.0
    ymaxs.append(ymx)

    classes_text.append(row['class'].encode('utf8'))
    classes.append(class_text_to_int(row['class']))

 

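  • Check that xmin < xmax and ymin < ymax hold. The following script repairs boxes whose coordinates are swapped (it writes the fixed XML files to a separate directory and truncates coordinates to integers):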
import os
import xml.etree.ElementTree as ET

xml_src_dir = "..."    # directory containing the original XML files
xml_dst_dir = "..."    # directory where the fixed XML files are saved

os.makedirs(xml_dst_dir, exist_ok=True)

for xml_src_file in os.listdir(xml_src_dir):
    if not xml_src_file.endswith('.xml'):
        continue
    # os.path.join works whether or not the directory strings end with a slash
    tree = ET.parse(os.path.join(xml_src_dir, xml_src_file))
    for obj in tree.iter('object'):
        bndbox = obj.find('bndbox')
        xmin = float(bndbox.find('xmin').text)
        ymin = float(bndbox.find('ymin').text)
        xmax = float(bndbox.find('xmax').text)
        ymax = float(bndbox.find('ymax').text)
        # Swap reversed coordinates so that xmin <= xmax and ymin <= ymax,
        # then truncate to integer pixel values as in the original script
        bndbox.find('xmin').text = str(int(min(xmin, xmax)))
        bndbox.find('ymin').text = str(int(min(ymin, ymax)))
        bndbox.find('xmax').text = str(int(max(xmin, xmax)))
        bndbox.find('ymax').text = str(int(max(ymin, ymax)))
    tree.write(os.path.join(xml_dst_dir, xml_src_file))
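The three checks above can also be run as a read-only scan before regenerating the TFRecords, so you can see which files are actually broken. Below is a minimal sketch, assuming Pascal VOC style XML annotations (for example as produced by labelImg), where each file has a `<size>` element with `width`/`height` and each `<object>` has a `<bndbox>`. The helper names `find_bad_boxes` and `check_dir` are hypothetical and not part of the TensorFlow Object Detection API; the script only reports problems and does not modify any file.

```python
import os
import xml.etree.ElementTree as ET


def find_bad_boxes(root):
    """Return (object name, problem) pairs for one parsed VOC-style annotation.

    Assumes <size><width>/<height> and <object><bndbox> elements are present.
    """
    size = root.find('size')
    width = float(size.find('width').text)
    height = float(size.find('height').text)
    problems = []
    for obj in root.iter('object'):
        name = obj.findtext('name', default='?')
        box = obj.find('bndbox')
        xmin, ymin, xmax, ymax = (float(box.find(tag).text)
                                  for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
        # Check 1: no coordinate may be negative
        if min(xmin, ymin, xmax, ymax) < 0:
            problems.append((name, 'negative coordinate'))
        # Check 2: the box must lie inside the image
        if max(xmin, xmax) > width or max(ymin, ymax) > height:
            problems.append((name, 'outside image bounds'))
        # Check 3: xmin < xmax and ymin < ymax must hold
        if xmin >= xmax or ymin >= ymax:
            problems.append((name, 'xmin >= xmax or ymin >= ymax'))
    return problems


def check_dir(xml_dir):
    """Print every problem found in the .xml files of xml_dir and return the count."""
    count = 0
    for fname in sorted(os.listdir(xml_dir)):
        if not fname.endswith('.xml'):
            continue
        root = ET.parse(os.path.join(xml_dir, fname)).getroot()
        for name, problem in find_bad_boxes(root):
            print(f'{fname}: {name}: {problem}')
            count += 1
    return count
```

Usage: call `check_dir("...")` with your annotation directory; a return value of 0 means none of the three checks fired. Running it again on the output directory after the fixes above is a quick way to confirm the repaired labels are clean.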