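When doing transfer learning in TensorFlow, the parameters stored in a checkpoint often do not line up with the parameters defined in the new network, so they cannot be loaded as-is. If a parameter has been renamed or its shape has changed, calling saver.restore() directly raises an error, and the restore step has to be written by hand instead.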
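Before choosing a restore method, it can help to see exactly where the checkpoint and the model disagree. The sketch below is a TF-free illustration that assumes both sides have already been reduced to name→shape dicts: the checkpoint side is what reader.get_variable_to_shape_map() returns, and the model side would come from tf.global_variables() with the ":0" suffix stripped. The function name diff_variables and the example variable names are hypothetical.

```python
# TF-free sketch: compare a checkpoint's name->shape map with the model's
# name->shape map to see why a plain saver.restore() would fail.
# Reducing both sides to plain dicts is an assumption made for illustration.
from typing import Dict, List

Shape = List[int]


def diff_variables(saved: Dict[str, Shape],
                   model: Dict[str, Shape]) -> Dict[str, List[str]]:
    """Classify variable names by how the model and the checkpoint disagree."""
    return {
        # same name and same shape: safe to restore
        "match": sorted(n for n in model if n in saved and saved[n] == model[n]),
        # same name, different shape: restoring this variable would fail
        "shape_mismatch": sorted(n for n in model if n in saved and saved[n] != model[n]),
        # in the model but not in the checkpoint (new or renamed variable)
        "missing_in_ckpt": sorted(n for n in model if n not in saved),
        # in the checkpoint but unused by the model (removed or renamed variable)
        "unused_in_ckpt": sorted(n for n in saved if n not in model),
    }


if __name__ == "__main__":
    saved = {"conv1/weights": [3, 3, 3, 64], "fc/weights": [4096, 1000],
             "fc/biases": [1000]}
    model = {"conv1/weights": [3, 3, 3, 64], "fc/weights": [4096, 10],
             "fc/biases": [10], "new_head/weights": [10, 2]}
    for kind, names in diff_variables(saved, model).items():
        print(kind, names)
```

A plain saver.restore() over all model variables only succeeds when the last three lists are empty; the two methods below differ in how they deal with the rest.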
Two ways of restoring are given here:
(1) The first method picks out the checkpoint parameters whose names and shapes both match those in the model, and then calls saver.restore() on just that subset.
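import tensorflow as tf  # TensorFlow 1.x API (tf.compat.v1 under TF 2.x)

def restore_1(session, save_file):
    # save_file is the checkpoint directory; find the latest checkpoint and read its variable shapes
    ckpt = tf.train.get_checkpoint_state(save_file)
    reader = tf.train.NewCheckpointReader(ckpt.model_checkpoint_path)
    saved_shapes = reader.get_variable_to_shape_map()
    # Map checkpoint-style names (without the ":0" suffix) to graph variables
    name2var = {var.name.split(':')[0]: var for var in tf.global_variables()}
    # (graph name, checkpoint name) pairs for variables present in both
    var_names = sorted([(var.name, var.name.split(':')[0])
                        for var in tf.global_variables()
                        if var.name.split(':')[0] in saved_shapes])
    restore_vars = []
    for var_name, saved_var_name in var_names:
        curr_var = name2var[saved_var_name]
        var_shape = curr_var.get_shape().as_list()
        # Keep only variables whose shape also matches the checkpoint
        if var_shape == saved_shapes[saved_var_name]:
            restore_vars.append(curr_var)
    # Restore just the matching subset; the remaining variables keep their
    # initial values, so run tf.global_variables_initializer() before calling this
    saver = tf.train.Saver(restore_vars)
    saver.restore(session, ckpt.model_checkpoint_path)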
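The selection rule in restore_1 can be illustrated without TensorFlow. In this sketch, plain (name, shape) pairs stand in for tf.global_variables(), which is an assumption made for illustration, and select_restorable is a hypothetical helper: it strips the ":0" suffix that TensorFlow appends to variable names and keeps a variable only if the checkpoint holds the same name with the same shape.

```python
# TF-free sketch of restore_1's filter. Plain (name, shape) pairs stand in
# for tf.global_variables(); this substitution is an assumption.
from typing import Dict, List, Tuple


def select_restorable(model_vars: List[Tuple[str, List[int]]],
                      saved_shapes: Dict[str, List[int]]) -> List[str]:
    """Return model variable names whose name and shape match the checkpoint."""
    selected = []
    for full_name, shape in sorted(model_vars):
        # "conv1/weights:0" -> "conv1/weights", the form used as a checkpoint key
        saved_name = full_name.split(":")[0]
        if saved_name in saved_shapes and saved_shapes[saved_name] == shape:
            selected.append(full_name)
    return selected


if __name__ == "__main__":
    model_vars = [("fc/weights:0", [4096, 10]), ("conv1/weights:0", [3, 3, 3, 64])]
    saved = {"conv1/weights": [3, 3, 3, 64], "fc/weights": [4096, 1000]}
    print(select_restorable(model_vars, saved))  # ['conv1/weights:0']
```

Only the variables selected this way would be passed to tf.train.Saver; a resized layer such as fc/weights above is skipped and keeps its initialized value.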
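(2) The second method restores the parameters manually and is more flexible than the first: only the chosen parameters are restored, and the checkpoint values can be modified before they are assigned.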
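def restore_2(session, save_file):
    # save_file is the checkpoint directory, as in restore_1
    ckpt = tf.train.get_checkpoint_state(save_file)
    reader = tf.train.NewCheckpointReader(ckpt.model_checkpoint_path)
    saved_shapes = reader.get_variable_to_shape_map()
    # Checkpoint names of graph variables that also exist in the checkpoint;
    # this scheme only handles two-level names of the form "scope/name"
    names = [var.name.split(':')[0] for var in tf.global_variables()
             if var.name.split(':')[0] in saved_shapes]
    names = [n for n in names if n.count('/') == 1]
    var_names1 = sorted(set(n.split('/')[0] for n in names))  # scope names
    var_names2 = sorted(set(n.split('/')[1] for n in names))  # variable names
    for key in var_names1:
        with tf.variable_scope(key, reuse=True):
            for key1 in var_names2:
                full_name = key + '/' + key1
                if full_name not in saved_shapes:
                    continue
                try:
                    # Raises ValueError if the variable does not exist in this scope
                    # or was created with tf.Variable rather than tf.get_variable
                    VAR = tf.get_variable(key1)
                    # value is a NumPy array and can be modified here before assignment
                    value = reader.get_tensor(full_name)
                    # tf.assign also raises ValueError when the shapes do not match
                    session.run(tf.assign(VAR, value))
                except ValueError:
                    print('ignore:', full_name)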
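The extra flexibility of restore_2 comes from handling each value yourself before it is assigned. The TF-free sketch below illustrates that idea under stated assumptions: dicts of nested lists stand in for the checkpoint reader and the graph variables, and manual_restore, shape_of and keep_first_columns are hypothetical names. The example transform reuses the first N output units of a larger fc layer; it is just one possible way to adapt a value whose shape no longer matches and is not part of the original method.

```python
# TF-free sketch of manual restoring with an optional per-variable transform.
# Dicts of nested lists stand in for the checkpoint and the model variables;
# all names here are hypothetical.
from typing import Any, Callable, Dict, List, Optional


def shape_of(value: Any) -> List[int]:
    """Shape of a (rectangular) nested list; a scalar has shape []."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return shape


def manual_restore(ckpt: Dict[str, Any], model: Dict[str, Any],
                   transform: Optional[Callable[[str, Any, List[int]], Any]] = None
                   ) -> List[str]:
    """Copy checkpoint values into model in place; return the skipped names."""
    skipped = []
    for name in sorted(model):
        if name not in ckpt:
            skipped.append(name)
            continue
        value = ckpt[name]
        if transform is not None:
            # Modify the checkpoint value before it is assigned
            value = transform(name, value, shape_of(model[name]))
        if shape_of(value) != shape_of(model[name]):
            skipped.append(name)  # plays the role of the ValueError branch in restore_2
            continue
        model[name] = value
    return skipped


def keep_first_columns(name: str, value: Any, target_shape: List[int]) -> Any:
    # Hypothetical rule: keep the first N output units of a larger fc layer
    if name.startswith("fc/") and shape_of(value) != target_shape:
        n = target_shape[-1]
        return [row[:n] for row in value] if len(target_shape) == 2 else value[:n]
    return value


if __name__ == "__main__":
    ckpt = {"fc/weights": [[1, 2, 3], [4, 5, 6]], "fc/biases": [0.1, 0.2, 0.3]}
    model = {"fc/weights": [[0, 0], [0, 0]], "fc/biases": [0, 0], "head/w": [0]}
    print(manual_restore(ckpt, model, keep_first_columns))  # ['head/w']
    print(model["fc/weights"])  # [[1, 2], [4, 5]]
```

Without a transform, the shape-mismatched fc variables would simply be skipped, which is roughly what restore_1 does; the transform hook is where the second method gains its flexibility.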