NotFoundError: Key bert_1/embeddings/LayerNorm/beta not found in checkpoint

报错

  • NotFoundError: Key bert_1/embeddings/LayerNorm/beta not found in checkpoint

  • OutOfRangeError (see above for traceback): Read less bytes than requested

  • NotFoundError: Key _CHECKPOINTABLE_OBJECT_GRAPH not found in checkpoint

  • NotFoundError: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint.

环境:Google colab下,tensorflow==1.11.0 bert-tensorflow ==1.0.1

代码

import tensorflow as tf
from bert import modeling
import numpy as np
import dataProcessing

words, masks, type_ids = dataProcessing.getX()

# 根据处理的数据得到bert输入
input_ids = tf.placeholder(shape=(None, 512), dtype=tf.int32, name="input_ids")   
input_mask = tf.placeholder(shape=(None, 512), dtype=tf.int32, name="input_mask")
segment_ids = tf.placeholder(shape=(None, 512), dtype=tf.int32, name="segment_ids")  # 全设置为0

# bert模型配置
bert_config = modeling.BertConfig.from_json_file("chinese_L-12_H-768_A-12/bert_config.json")

# 加载bert模型
model = modeling.BertModel(
    config=bert_config,   # 模型的配置文件
    is_training=False,  # 参数是否可训练
    input_ids=input_ids,  # 每个字对应vocab中的id
    input_mask=input_mask,  # 在有实值token位上为1,padding位上为0
    token_type_ids=segment_ids,  # segment_ids 为段id,在NSP任务中可用于区分句子
    use_one_hot_embeddings=False
)

# 得到bert层的输出 get_sequence_output的输出结果shape为(?,512,768) get_pooled_output输出结果shape(?,768)
embedding = model.get_sequence_output()

# 模型存储位置
checkpoint = "chinese_L-12_H-768_A-12/bert_model.ckpt"
with tf.Session() as sess:
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, checkpoint)
    sess.run(embedding, feed_dict={input_ids: np.asarray(words), input_mask: np.asarray(masks), segment_ids: np.asarray(type_ids)})

一开始查到的可能原因是ckpt文件损坏导致的错误,可重新下载文件后还是报同样的错误。因为这段代码在pc(tensorflow版本为1.14.0)本地运行是没有错误的,怀疑是tensorflow版本问题造成的,重新安装了1.14.0版本,然而还是出现错误。。。
最后按照一个博主的方法,在words, masks, type_ids = dataProcessing.getX()
添加了一行代码tf.reset_default_graph(),运行成功。

猜测原因

tf.session会自动创建并维护一个默认的计算图。而每次加载model进行“预测”都需要建立计算图,同时增加新的节点,故导致变量名、计算图的键冲突。
tf.reset_default_graph()用于清除默认图形堆栈并重置全局默认图形。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值