BERT-TensorFlow版微调后模型过大解决方案

背景

TensorFlow中加载预训练的BERT模型(base),在下游微调后发现最终模型比原始模型大许多。
微调后的模型
原始用来初始化的BERT模型
对于某些竞赛,要求提交的代码、模型和数据文件有容量限制,为此需要尽量在模型上瘦身。

分析和解决

预训练的BERT模型只有390MB,但是微调结果模型有1.2GB。这是由于训练过程中的checkpoints由于包含了每个权重变量对应的Adam momentumvariance变量,所以训练后的checkpoints是分布式checkpoint的3倍。这多出来的Adam momentum和variance 实际上不是模型的一部分,其作用是能够暂停并在中途恢复训练。

解决方案1:

将模型转为pb格式,然后在预测时候加载pb格式的模型进行推理操作。这里就不过多阐述这个方案。

解决方案2:

直接将模型中多余的变量去掉,实现模型瘦身。具体实现脚本如下所示:

import tensorflow as tf
meta_file = "./online_models/model.ckpt-20061.meta"
checkpoint_file = "./online_models/model.ckpt-20061"
sess = tf.Session()
imported_meta = tf.train.import_meta_graph(meta_file)
imported_meta.restore(sess, checkpoint_file)
my_vars = []
for var in tf.all_variables():
    if 'adam_v' not in var.name and 'adam_m' not in var.name:
        my_vars.append(var)
saver = tf.train.Saver(my_vars)
saver.save(sess, './online_models/divorce_best_model.ckpt')
评论 5 您还未登录,请先 登录 后发表或查看评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
©️2022 CSDN 皮肤主题:大白 设计师:CSDN官方博客 返回首页

打赏作者

JasonLiu1919

你的鼓励将是我创作的最大动力

¥2 ¥4 ¥6 ¥10 ¥20
输入1-500的整数
余额支付 (余额:-- )
扫码支付
扫码支付:¥2
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值