Because of TensorFlow version differences, calling its functions often raises errors.
I was originally using the following code:
import tensorflow.compat.v1 as tf1
tf1.disable_v2_behavior()
checkpoint_path = r'D:/INCREASE-main/Beijing/data/INCREASE_SP_pretrained'
# Read data from checkpoint file
reader = tf1.train.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
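The goal was to inspect the trained nodes, that is, the variables stored in the checkpoint, but this raised an error.
I changed it to: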
import tensorflow as tf1
tf1.compat.v1.disable_v2_behavior()
checkpoint_path = r'D:/INCREASE-main/Beijing/data/INCREASE_SP_pretrained'
# Read data from checkpoint file
reader = tf1.compat.v1.train.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
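As a side note, the same listing can probably be obtained without going through `compat.v1` at all. The following is only a minimal sketch, assuming the installed TensorFlow exposes `tf.train.list_variables`; the helper names `load_variable_shapes` and `format_variable_shapes` are my own and do not come from the original code.

```python
def load_variable_shapes(checkpoint_path):
    """Return {variable_name: shape} for a checkpoint prefix.

    Assumption: tf.train.list_variables is available in the installed
    TensorFlow; it yields (name, shape) pairs for the checkpoint.
    """
    # Imported lazily so the formatting helper below works without TensorFlow.
    import tensorflow as tf

    return {
        name: list(shape)
        for name, shape in tf.train.list_variables(checkpoint_path)
    }


def format_variable_shapes(var_to_shape_map):
    """Return one 'tensor_name: <name>  shape: <shape>' line per variable, sorted by name."""
    return [
        "tensor_name: {}  shape: {}".format(name, var_to_shape_map[name])
        for name in sorted(var_to_shape_map)
    ]
```

To use it, pass the same `checkpoint_path` as above to `load_variable_shapes`, then print each line returned by `format_variable_shapes`. Sorting by name keeps the listing stable between runs, which makes it easier to compare two checkpoints.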
With the `tf1.compat.v1` form, the script runs. The result is as follows: