Errors
- NotFoundError: Key bert_1/embeddings/LayerNorm/beta not found in checkpoint
- OutOfRangeError (see above for traceback): Read less bytes than requested
- NotFoundError: Key _CHECKPOINTABLE_OBJECT_GRAPH not found in checkpoint
- NotFoundError: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint.
Environment: Google Colab, tensorflow==1.11.0, bert-tensorflow==1.0.1
Code
```python
import tensorflow as tf
from bert import modeling
import numpy as np
import dataProcessing

words, masks, type_ids = dataProcessing.getX()
# Build the BERT inputs from the preprocessed data
input_ids = tf.placeholder(shape=(None, 512), dtype=tf.int32, name="input_ids")
input_mask = tf.placeholder(shape=(None, 512), dtype=tf.int32, name="input_mask")
segment_ids = tf.placeholder(shape=(None, 512), dtype=tf.int32, name="segment_ids")  # all set to 0
# BERT model configuration
bert_config = modeling.BertConfig.from_json_file("chinese_L-12_H-768_A-12/bert_config.json")
# Build the BERT model
model = modeling.BertModel(
    config=bert_config,             # model configuration
    is_training=False,              # eval mode: dropout is not applied
    input_ids=input_ids,            # vocab id of each character
    input_mask=input_mask,          # 1 at real-token positions, 0 at padding positions
    token_type_ids=segment_ids,     # segment ids, used to tell the two sentences apart in the NSP task
    use_one_hot_embeddings=False
)
# BERT output: get_sequence_output() has shape (?, 512, 768); get_pooled_output() has shape (?, 768)
embedding = model.get_sequence_output()
# Checkpoint location
checkpoint = "chinese_L-12_H-768_A-12/bert_model.ckpt"
with tf.Session() as sess:
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, checkpoint)
    result = sess.run(embedding, feed_dict={input_ids: np.asarray(words),
                                            input_mask: np.asarray(masks),
                                            segment_ids: np.asarray(type_ids)})
```
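The last error message says a variable name in the graph is missing from the checkpoint. A quick way to see which names differ is to compare the two sets of names directly. The sketch below is a hypothetical helper (`diff_names` is not part of TensorFlow or bert-tensorflow). Inside the session, its inputs could come from `[v.name for v in tf.global_variables()]` and `[name for name, _ in tf.train.list_variables(checkpoint)]`. The sample names in the usage lines are made up for illustration.

```python
def diff_names(graph_var_names, checkpoint_keys):
    """Hypothetical helper: compare graph variable names with checkpoint keys.

    Returns (missing, unused):
      missing: variables in the graph that the checkpoint does not contain
               (these are what make saver.restore() raise NotFoundError)
      unused:  checkpoint keys that no graph variable maps to
    """
    # Graph variable names carry an output suffix such as ":0";
    # checkpoint keys do not, so strip it before comparing.
    graph = {name.split(":")[0] for name in graph_var_names}
    ckpt = set(checkpoint_keys)
    missing = sorted(graph - ckpt)
    unused = sorted(ckpt - graph)
    return missing, unused


# Made-up sample: the model was built twice in the same default graph.
graph_names = ["bert/embeddings/LayerNorm/beta:0",
               "bert_1/embeddings/LayerNorm/beta:0"]
ckpt_keys = ["bert/embeddings/LayerNorm/beta",
             "cls/predictions/output_bias"]

missing, unused = diff_names(graph_names, ckpt_keys)
print(missing)  # ['bert_1/embeddings/LayerNorm/beta']
print(unused)   # ['cls/predictions/output_bias']
```

A non-empty `missing` list that contains `bert_1/...` names points to a duplicated model in the graph rather than a corrupted file.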
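At first, the most likely cause I found was a corrupted ckpt file, but after re-downloading the files I still got the same errors. Because the same code ran without errors locally on my PC (TensorFlow 1.14.0), I suspected a TensorFlow version problem and reinstalled 1.14.0, but the errors persisted...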
In the end, following a blogger's method, I added one line, tf.reset_default_graph(), right after words, masks, type_ids = dataProcessing.getX(), and the code ran successfully.
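Suspected cause
TensorFlow 1.x keeps a global default graph, and the tf.Session runs whatever is in it. Every time the model-building code runs, it adds new nodes to that same graph. In Colab the Python process stays alive between cell runs, so re-running the cell builds a second copy of BERT in the existing graph. TensorFlow then renames the second copy's scope to bert_1, which matches the first error: the Saver asks for the key bert_1/embeddings/LayerNorm/beta, and the checkpoint has no such key. Running the script locally presumably starts a fresh process, and therefore a fresh graph, each time, which would explain why the PC never showed the error.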
tf.reset_default_graph() clears the default graph stack and resets the global default graph.
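To make the suspected mechanism concrete, here is a toy, pure-Python model of it. It does not use TensorFlow. `ToyGraph`, `build_bert_variable` and this `reset_default_graph` are hypothetical stand-ins. The only real behavior they imitate is that TF 1.x renames a repeated default scope to bert_1, bert_2, and so on, while a reset discards all names already taken.

```python
class ToyGraph:
    """Hypothetical stand-in for the name registry of a TF 1.x default graph."""

    def __init__(self):
        self._used = set()

    def unique_scope(self, name):
        # The first request gets `name`; later requests get name_1, name_2, ...
        if name not in self._used:
            self._used.add(name)
            return name
        i = 1
        while f"{name}_{i}" in self._used:
            i += 1
        unique = f"{name}_{i}"
        self._used.add(unique)
        return unique


_default_graph = ToyGraph()


def reset_default_graph():
    # Toy analogue of tf.reset_default_graph(): start over with an empty graph.
    global _default_graph
    _default_graph = ToyGraph()


def build_bert_variable():
    # One "model build": open the default scope "bert" and create one variable.
    scope = _default_graph.unique_scope("bert")
    return f"{scope}/embeddings/LayerNorm/beta"


checkpoint_keys = {"bert/embeddings/LayerNorm/beta"}

first = build_bert_variable()   # first run of the cell
second = build_bert_variable()  # re-running the cell in the same process
print(second, second in checkpoint_keys)  # bert_1/embeddings/LayerNorm/beta False

reset_default_graph()
third = build_bert_variable()   # re-running after a reset
print(third, third in checkpoint_keys)    # bert/embeddings/LayerNorm/beta True
```

This is why resetting the graph before the model is rebuilt makes the variable names line up with the checkpoint again.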