抽取bert某几层参数保存

 

import tensorflow as tf
import os

sess = tf.Session()
last_name = 'bert_model.ckpt'
model_path = 'chinese_L-12_H-768_A-12'
imported_meta = tf.train.import_meta_graph(os.path.join(model_path, last_name + '.meta'))
imported_meta.restore(sess, os.path.join(model_path, last_name))
init_op = tf.local_variables_initializer()
sess.run(init_op)

bert_dict = {}
# 获取待保存的层数节点
for var in tf.global_variables():
    # print(var)
    # 提取第0层和第11层和其它的参数,其余1-10层去掉,存储变量名的数值
    if var.name.startswith('bert/encoder/layer_') and not var.name.startswith(
            'bert/encoder/layer_0') and not var.name.startswith('bert/encoder/layer_11'):
        pass
    else:
        bert_dict[var.name] = sess.run(var).tolist()

# print('bert_dict:{}'.format(bert_dict))
# 真实保存的变量信息
need_vars = []
for var in tf.global_variables():
    if var.name.startswith('bert/encoder/layer_') and not var.name.startswith(
            'bert/encoder/layer_0/') and not var.name.startswith('bert/encoder/layer_1/'):
        pass
    elif var.name.startswith('bert/encoder/layer_1/'):
        # 寻找11层的var name,将11层的参数给第一层使用
        new_name = var.name.replace("bert/encoder/layer_1", "bert/encoder/layer_11")
        op = tf.assign(var, bert_dict[new_name])
        sess.run(op)
        need_vars.append(var)
        print(var)
    else:
        need_vars.append(var)
        print('####',var)

# 保存model
saver = tf.train.Saver(need_vars)
saver.save(sess, os.path.join('chinese_L-12_H-768_A-12_pruning', 'bert_pruning_2_layer.ckpt'))

 

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

samoyan

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值