# Load pretrained parameters.
# `init_checkpoint` may list several checkpoint files separated by commas; each
# file initializes those trainable variables whose names it also contains.
tvars = tf.trainable_variables()
# Union of every matched variable name (both "name" and "name:0" forms) across
# all listed checkpoints.
initialized_variable_names = {}
scaffold_fn = None
if init_checkpoint:  # path(s) of the pretrained model parameter file(s)
    for init_file in init_checkpoint.split(","):
        # Tolerate whitespace around commas and empty entries (e.g. a trailing
        # comma) instead of handing " path" or "" to the checkpoint reader.
        init_file = init_file.strip()
        if not init_file:
            continue
        # NOTE(review): in this file get_assignment_map_from_checkpoint is
        # defined *after* this code; if this fragment executes at module import
        # time the call raises NameError. Presumably it lives inside a function
        # (e.g. a model_fn) that runs later — confirm.
        assignment_map, tmp_init_map = get_assignment_map_from_checkpoint(
            tvars, init_file)
        # NOTE(review): a variable present in more than one listed checkpoint
        # is passed to init_from_checkpoint once per checkpoint; confirm which
        # checkpoint's value takes effect.
        tf.train.init_from_checkpoint(init_file, assignment_map)
        initialized_variable_names.update(tmp_init_map)
def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
    """Compute the union of the current variables and checkpoint variables.

    Args:
        tvars: Iterable of variables; only each item's ``.name`` is read. A
            trailing ``:<digits>`` suffix (e.g. ``":0"``) is removed from the
            name before it is compared with checkpoint names.
        init_checkpoint: Checkpoint path handed to ``tf.train.list_variables``.

    Returns:
        A tuple ``(assignment_map, initialized_variable_names)``:

        * ``assignment_map`` -- ``collections.OrderedDict`` mapping every
          checkpoint variable name that also appears among ``tvars`` to
          itself, in the order ``tf.train.list_variables`` lists them.
        * ``initialized_variable_names`` -- dict whose keys are those same
          names plus each name with ``":0"`` appended, every value being 1.
    """
    suffix_re = re.compile(r"^(.*):\d+$")

    # Names of the variables in the graph, with any ":<digits>" suffix dropped.
    # Only membership is needed, so a set suffices.
    graph_names = set()
    for variable in tvars:
        full_name = variable.name
        hit = suffix_re.match(full_name)
        graph_names.add(hit.group(1) if hit else full_name)

    assignment_map = collections.OrderedDict()
    initialized_variable_names = {}
    # Each entry is a (name, shape) pair; the shape is not needed here.
    for ckpt_name, _shape in tf.train.list_variables(init_checkpoint):
        if ckpt_name not in graph_names:
            continue
        assignment_map[ckpt_name] = ckpt_name
        # Record both spellings so callers can look up either a bare variable
        # name or its ":0" tensor name.
        initialized_variable_names[ckpt_name] = 1
        initialized_variable_names[ckpt_name + ":0"] = 1

    return assignment_map, initialized_variable_names