Problem description: running the Faster R-CNN algorithm with tensorflow-gpu 2.x produces the errors below. Training with TensorFlow 1.15 (CPU) does not:
module 'tensorflow.python.pywrap_tensorflow' has no attribute 'NewCheckpointReader'
    if v.name.split(':')[0] in var_keep_dic:
TypeError: argument of type 'NoneType' is not iterable
Full console output:
Traceback (most recent call last):
  File "E:\SOTA\Object Detection\Faster-RCNN-TensorFlow-Python3-master\Faster-RCNN-TensorFlow-Python3-master\train.py", line 221, in <module>
    train.train()
  File "E:\SOTA\Object Detection\Faster-RCNN-TensorFlow-Python3-master\Faster-RCNN-TensorFlow-Python3-master\train.py", line 127, in train
    variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)
  File "E:\SOTA\Object Detection\Faster-RCNN-TensorFlow-Python3-master\Faster-RCNN-TensorFlow-Python3-master\lib\nets\vgg16.py", line 71, in get_variables_to_restore
module 'tensorflow.python.pywrap_tensorflow' has no attribute 'NewCheckpointReader'
    if v.name.split(':')[0] in var_keep_dic:
TypeError: argument of type 'NoneType' is not iterable
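The two messages are linked: the AttributeError is caught and printed, and the TypeError comes afterwards. The minimal reproduction below shows how. It does not use TensorFlow. `simulated_get_variables` is a hypothetical stand-in that raises the same AttributeError and then handles it the way the checkpoint-reading method does. `conv1/weights` is a made-up variable name.

```python
def simulated_get_variables(file_name):
    """Hypothetical stand-in for the checkpoint-reading method in train.py."""
    try:
        # Simulate the TF 2.x failure: the attribute no longer exists.
        raise AttributeError("module 'tensorflow.python.pywrap_tensorflow' "
                             "has no attribute 'NewCheckpointReader'")
    except Exception as e:  # same broad except as the original method
        print(str(e))  # the error is printed, not re-raised
    # Nothing is returned on this path, so the caller receives None.


var_keep_dic = simulated_get_variables("vgg16.ckpt")
print(var_keep_dic is None)  # True

try:
    # Same kind of membership test as vgg16.py line 71.
    found = "conv1/weights" in var_keep_dic
except TypeError as err:
    print("TypeError:", err)
```

This is also why checking the model path alone does not help. The real problem is the missing function, and the broad `except` hides that problem until the later membership test.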
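Some posts attribute this to the location of the model file. However, the error persisted even after I confirmed that the weights are at ./data/imagenet_weights/vgg16.ckpt.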
The fix:
Step 1: TensorFlow 2.x is not compatible with code written for TensorFlow 1.x. Replace every occurrence of
import tensorflow as tf
with
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
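Editing every file by hand is tedious and easy to get wrong, so here is a sketch that automates Step 1 using only the standard library. The helper names `patch_source` and `patch_tree` are hypothetical. The sketch assumes the import appears exactly as `import tensorflow as tf` on its own line. Other forms, such as `from tensorflow import ...`, are left untouched and still need a manual check. The script rewrites files in place, so commit or back up the project before you run it.

```python
import pathlib
import re

# Matches a line that is exactly "import tensorflow as tf" (indentation and
# trailing spaces allowed), so e.g. "import tensorflow_hub as hub" is untouched.
_TF_IMPORT = re.compile(
    r"^(?P<indent>[ \t]*)import tensorflow as tf[ \t]*$", re.MULTILINE)


def patch_source(source: str) -> str:
    """Rewrite TF1-style imports into the tf.compat.v1 form from Step 1."""
    def _replace(match):
        indent = match.group("indent")
        # Keep the original indentation so imports inside functions still work.
        return (f"{indent}import tensorflow.compat.v1 as tf\n"
                f"{indent}tf.disable_v2_behavior()")
    return _TF_IMPORT.sub(_replace, source)


def patch_tree(root: str) -> list:
    """Apply patch_source to every .py file under root; return changed paths."""
    changed = []
    for path in pathlib.Path(root).rglob("*.py"):
        text = path.read_text(encoding="utf-8")
        new_text = patch_source(text)
        if new_text != text:  # already-patched files are left alone
            path.write_text(new_text, encoding="utf-8")
            changed.append(str(path))
    return changed
```

For example, run `patch_tree(".")` from the project root and inspect the returned list of modified files. The function is idempotent: running it again changes nothing, because the regex no longer matches the rewritten lines.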
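Step 2: The message 'tensorflow.python.pywrap_tensorflow' has no attribute 'NewCheckpointReader' appears because this 1.x call is no longer available in 2.x. The function that makes the call catches the exception and prints it, then returns None. As a result, var_keep_dic is None, and the membership test in vgg16.py raises the TypeError. Open train.py and find the following function: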
def get_variables_in_checkpoint_file(self, file_name):
    try:
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        var_to_shape_map = reader.get_variable_to_shape_map()
        return var_to_shape_map
    except Exception as e:  # pylint: disable=broad-except
        print(str(e))
        if "corrupted compressed block contents" in str(e):
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
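Replace the statement
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
with
reader = tf.train.NewCheckpointReader(file_name)
Because Step 1 imports tf as tensorflow.compat.v1, this call resolves to tf.compat.v1.train.NewCheckpointReader.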
The modified function:
def get_variables_in_checkpoint_file(self, file_name):
    try:
        # reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        reader = tf.train.NewCheckpointReader(file_name)
        var_to_shape_map = reader.get_variable_to_shape_map()
        return var_to_shape_map
    except Exception as e:  # pylint: disable=broad-except
        print(str(e))
        if "corrupted compressed block contents" in str(e):
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
With these changes, the model trains on the GPU.
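Changing one line fixes this error, but one weakness remains. The except branch only prints the error, so any other checkpoint problem (a wrong path, for example) still shows up later as the same confusing TypeError. The sketch below shows an optional hardening. It is a standalone function, not the repository's method. It reads the variable list with tf.train.list_variables, which returns (name, shape) pairs in both TF 1.15 and TF 2.x, and it re-raises errors instead of returning None. The `list_variables` parameter is a hypothetical hook added only so that a fake reader can be injected when TensorFlow is not installed.

```python
def get_variables_in_checkpoint_file(file_name, list_variables=None):
    """Return {variable_name: shape}; raise instead of silently returning None."""
    if list_variables is None:
        # Imported lazily so the function can be exercised without TensorFlow.
        import tensorflow as tf
        list_variables = tf.train.list_variables
    try:
        return {name: shape for name, shape in list_variables(file_name)}
    except Exception as e:  # pylint: disable=broad-except
        if "corrupted compressed block contents" in str(e):
            raise RuntimeError(
                "It's likely that your checkpoint file has been compressed "
                "with SNAPPY.") from e
        raise  # surface the real error right here, not as a later TypeError
```

Called as `get_variables_in_checkpoint_file("./data/imagenet_weights/vgg16.ckpt")`, it should stop immediately with TensorFlow's own error if the checkpoint cannot be read.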
If training then reports "image invalid, skipping", set roi_bg_threshold_low in config.py to 0.0 and the message no longer appears.
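This post does not explain why this setting helps. One plausible explanation relies on an assumption: that images are filtered by the rule common in Faster R-CNN training code. Under that rule, an image is kept only if at least one RoI counts as foreground or as background, and background means an overlap in [roi_bg_threshold_low, roi_bg_threshold_high). RoIs with zero overlap then count as background only once the lower bound is 0.0. The plain-Python sketch below models this assumed rule. `is_valid`, its parameter names, and the default thresholds are illustrative and were not read from config.py, so check the actual filter in your copy of the repository.

```python
def is_valid(max_overlaps, fg_thresh=0.5, bg_thresh_hi=0.5, bg_thresh_lo=0.1):
    """Assumed rule: keep an image only if it has a foreground or background RoI."""
    # Foreground: RoIs whose best overlap with a ground-truth box is high enough.
    has_fg = any(o >= fg_thresh for o in max_overlaps)
    # Background: RoIs whose best overlap falls in [bg_thresh_lo, bg_thresh_hi).
    has_bg = any(bg_thresh_lo <= o < bg_thresh_hi for o in max_overlaps)
    return has_fg or has_bg


overlaps = [0.0, 0.0, 0.05]  # hypothetical RoIs that barely touch any box
print(is_valid(overlaps))                    # False: the image would be skipped
print(is_valid(overlaps, bg_thresh_lo=0.0))  # True: zero-overlap RoIs now count
```

Under this assumption, lowering the bound keeps these images in training, with their low-overlap RoIs used as background samples.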
References
https://blog.csdn.net/knighthood2001/article/details/125840934
https://blog.csdn.net/weixin_41463944/article/details/107979692