FasterRCNN tensorflow 2.x版本遇到‘NewCheckpointReader’问题

问题描述:使用tensorflow-gpu 2.x版本运行FasterRCNN算法出现,使用tensorflow 1.15 cpu进行训练则不会出现

module ‘tensorflow.python.pywrap_tensorflow’ has no attribute ‘NewCheckpointReader’
if v.name.split(‘:’)[0] in var_keep_dic:
TypeError: argument of type ‘NoneType’ is not iterable
代码提示如下:

Traceback (most recent call last):
  File "E:\SOTA\Object Detection\Faster-RCNN-TensorFlow-Python3-master\Faster-RCNN-TensorFlow-Python3-master\train.py", line 221, in <module>
    train.train()
  File "E:\SOTA\Object Detection\Faster-RCNN-TensorFlow-Python3-master\Faster-RCNN-TensorFlow-Python3-master\train.py", line 127, in train
    variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)
  File "E:\SOTA\Object Detection\Faster-RCNN-TensorFlow-Python3-master\Faster-RCNN-TensorFlow-Python3-master\lib\nets\vgg16.py", line 71, in get_variables_to_restore
module 'tensorflow.python.pywrap_tensorflow' has no attribute 'NewCheckpointReader'
    if v.name.split(':')[0] in var_keep_dic:
TypeError: argument of type 'NoneType' is not iterable

有帖子说是模型存放路径问题,但是在确保模型位置为./data/imagenet_weights/vgg16.ckpt 后仍无法解决。
具体解决办法如下:
步骤1:由于使用tensorflow 2.x版本与tensorflow 1.x版本不兼容,因此需要将所有出现以下代码

import tensorflow as tf

的地方替换为

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

步骤2:提示’tensorflow.python.pywrap_tensorflow’ has no attribute ‘NewCheckpointReader’,由于1.x版本中的函数调用跟2.x不同,因此定位到train.py文件中,找到以下部分

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")

reader = pywrap_tensorflow.NewCheckpointReader(file_name)

语句替换成

reader = tf.train.NewCheckpointReader(file_name)

修改后的函数部分如下:

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            # reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            reader = tf.train.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")

修改后就可以使用GPU进行模型训练了
遇到image invalid, skipping问题,需要将config.py文件中的roi_bg_threshold_low改成0.0就不会出现这个问题
参考文章
https://blog.csdn.net/knighthood2001/article/details/125840934
https://blog.csdn.net/weixin_41463944/article/details/107979692

  • 6
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值