BERT-TensorFlow版微调后模型过大解决方案

背景

TensorFlow中加载预训练的BERT模型(base),在下游微调后发现最终模型比原始模型大许多。
微调后的模型
原始用来初始化的BERT模型
对于某些竞赛,要求提交的代码、模型和数据文件有容量限制,为此需要尽量在模型上瘦身。

分析和解决

预训练的BERT模型只有390MB,但是微调结果模型有1.2GB。这是由于训练过程中的checkpoints由于包含了每个权重变量对应的Adam momentumvariance变量,所以训练后的checkpoints是分布式checkpoint的3倍。这多出来的Adam momentum和variance 实际上不是模型的一部分,其作用是能够暂停并在中途恢复训练。

解决方案1:

将模型转为pb格式,然后在预测时候加载pb格式的模型进行推理操作。这里就不过多阐述这个方案。

解决方案2:

直接将模型中多余的变量去掉,实现模型瘦身。具体实现脚本如下所示:

import tensorflow as tf
meta_file = "./online_models/model.ckpt-20061.meta"
checkpoint_file = "./online_models/model.ckpt-20061"
sess = tf.Session()
imported_meta = tf.train.import_meta_graph(meta_file)
imported_meta.restore(sess, checkpoint_file)
my_vars = []
for var in tf.all_variables():
    if 'adam_v' not in var.name and 'adam_m' not in var.name:
        my_vars.append(var)
saver = tf.train.Saver(my_vars)
saver.save(sess, './online_models/divorce_best_model.ckpt')
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值