# 获取模型中所有的训练参数。
tvars = tf.trainable_variables()
# 加载BERT模型
(assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, pm.init_checkpoint)
tf.train.init_from_checkpoint(pm.init_checkpoint, assignment_map)
tf.logging.info("**** Trainable Variables ****")
# 打印加载模型的参数
for var in tvars:
init_string = ""
if var.name in initialized_variable_names:
init_string = ", *INIT_FROM_CKPT*"
tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
init_string)
session = tf.Session()
session.run(tf.global_variables_initializer())
获取预训练bert模型中所有的训练参数
最新推荐文章于 2023-04-11 20:15:35 发布