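Recently I looked into how to use the pretrained BERT models on tensorflow-hub. The official tutorial on using these pretrained models, which I found first, cannot be opened from within China. So I read a lot of blog posts and fell into plenty of pitfalls along the way. Only when I finally found a version of the tutorial I could open did I realize that using it is actually simple.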
Versions used:
tensorflow version: 2.3.0
tensorflow-hub version: 0.9.0
python version: 3.7.6
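Data preparation:
First, anyone familiar with BERT knows it takes three inputs: inputIds, inputMask and segmentIds (the token ids, a mask with 1 for real tokens and 0 for padding, and segment ids marking which sentence each token belongs to). I won't go into them further here; a quick search turns up plenty of explanations.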
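The code that follows passes inputIds, inputMask and segmentIds to the model without showing how they are built. Below is a minimal pure-Python sketch of one way to build them. It assumes the text has already been converted to word-piece ids with the module's own vocabulary (tokenization is not shown), and it assumes the special-token ids of the standard Google BERT vocab.txt ([PAD]=0, [CLS]=101, [SEP]=102); check these against the vocab file of the module you actually load. The function name `build_bert_inputs` is hypothetical.

```python
from typing import List, Optional, Tuple

# Assumed special-token ids from the standard Google BERT vocab.txt;
# verify them against the vocab file shipped with your module.
PAD_ID, CLS_ID, SEP_ID = 0, 101, 102


def build_bert_inputs(tokens_a: List[int],
                      tokens_b: Optional[List[int]] = None,
                      max_seq_length: int = 256
                      ) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: word-piece ids -> (inputIds, inputMask, segmentIds)."""
    a, b = list(tokens_a), list(tokens_b or [])
    # Reserve room for [CLS] plus one [SEP] (single sentence) or two (pair).
    budget = max_seq_length - (3 if b else 2)
    # Trim the longer sequence first until the pair fits the budget.
    while len(a) + len(b) > budget:
        if len(a) > len(b):
            a.pop()
        else:
            b.pop()
    input_ids = [CLS_ID] + a + [SEP_ID]
    segment_ids = [0] * len(input_ids)      # sentence A -> segment 0
    if b:
        input_ids += b + [SEP_ID]
        segment_ids += [1] * (len(b) + 1)   # sentence B -> segment 1
    input_mask = [1] * len(input_ids)       # 1 = real token
    pad = max_seq_length - len(input_ids)
    input_ids += [PAD_ID] * pad             # pad up to max_seq_length
    input_mask += [0] * pad                 # 0 = padding
    segment_ids += [0] * pad
    return input_ids, input_mask, segment_ids
```

To feed `model.predict`, build one triple per sentence and stack each field into an int32 array of shape [batch, 256], for example with `np.array(list_of_ids, dtype=np.int32)`.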
Code that gets the BERT outputs directly:
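```python
import tensorflow as tf
import tensorflow_hub as hub

# Assumed handle: the Chinese BERT-Base module referenced in Experiment 1 below
BERT_URL = "https://hub.tensorflow.google.cn/tensorflow/bert_zh_L-12_H-768_A-12/2"
max_seq_length = 256
input_word_ids = tf.keras.layers.Input(shape=(max_seq_length,),
                                       dtype=tf.int32, name="input_word_ids")
input_mask = tf.keras.layers.Input(shape=(max_seq_length,),
                                   dtype=tf.int32, name="input_mask")
segment_ids = tf.keras.layers.Input(shape=(max_seq_length,),
                                    dtype=tf.int32, name="segment_ids")
# Set trainable=False to freeze the BERT weights (feature extraction only)
module = hub.KerasLayer(BERT_URL, trainable=False)
# The layer expects its inputs in the order [input_word_ids, input_mask, segment_ids]
pooled_output, sequence_output = module([input_word_ids, input_mask, segment_ids])
# Build the model's inputs and outputs
model = tf.keras.Model(inputs=[input_word_ids, input_mask, segment_ids],
                       outputs=[pooled_output, sequence_output])
# Get the outputs; inputIds, inputMask, segmentIds are int32 arrays of shape [batch, 256]
output = model.predict([inputIds, inputMask, segmentIds])
pool_out, sequence_out = output
# pool_out: shape=[batch, 768]; sequence_out: shape=[batch, 256, 768]
```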
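With two outputs, `model.predict` returns a list of two NumPy arrays. The sketch below shows how they can be unpacked and how inputMask can be used to keep only the vectors of real (non-padding) tokens. It uses random stand-in arrays and a made-up mask in place of real BERT outputs (an assumption, so it runs without downloading the module). Note that in the original BERT, pooled_output is computed from the [CLS] vector by an extra dense layer with tanh activation, so it is not identical to `sequence_out[:, 0]`.

```python
import numpy as np

batch, max_seq_length, hidden = 2, 256, 768
# Random stand-ins for model.predict(...): [pooled_output, sequence_output]
rng = np.random.default_rng(0)
output = [rng.standard_normal((batch, hidden), dtype=np.float32),
          rng.standard_normal((batch, max_seq_length, hidden), dtype=np.float32)]
# Made-up mask: example 0 has 10 real tokens, example 1 has 5
input_mask = np.zeros((batch, max_seq_length), dtype=np.int32)
input_mask[0, :10] = 1
input_mask[1, :5] = 1

pool_out, sequence_out = output
# Vector at position 0 ([CLS]) of every example: shape [batch, 768]
cls_vectors = sequence_out[:, 0, :]
# Per-example token vectors with padding dropped: list of [n_real_tokens, 768]
token_vectors = [seq[m.astype(bool)] for seq, m in zip(sequence_out, input_mask)]
```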
-------------------------------------------------BUG----------------------------------------------
I also tried getting the BERT outputs the way the blog in reference link 3 does, but ran into the following problem:
ValueError: Could not find matching function to call loaded from the SavedModel. Got:
Positional arguments (2 total):
* False
* None:
```python
# Experiment 1: argument names taken from https://hub.tensorflow.google.cn/tensorflow/bert_zh_L-12_H-768_A-12/2
outputs, _ = hub_module(input_word_ids=tf.constant(tmp_inputids),
                        input_mask=tf.constant(tmp_inputMask),
                        segment_ids=tf.constant(tmp_segmentIds))
```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-45-b1533b83b191> in <module>
2 outputs,_ = hub_module(input_word_ids=tf.constant(tmp_inputids),
3 input_mask=tf.constant(tmp_inputMask),
----> 4 segment_ids=tf.constant(tmp_segmentIds))
5
6 # # Experiment 2: argument names taken from the error message
/opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py in _call_attribute(instance, *args, **kwargs)
507
508 def _call_attribute(instance, *args, **kwargs):
--> 509 return instance.__call__(*args, **kwargs)
510
511
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
778 else:
779 compiler = "nonXla"
--> 780 result = self._call(*args, **kwds)
781
782 new_tracing_count = self._get_tracing_count()
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
812 # In this case we have not created variables on the first call. So we can
813 # run the first trace but we should fail if variables are created.
--> 814 results = self._stateful_fn(*args, **kwds)
815 if self._created_variables:
816 raise ValueError("Creating variables on a non-first call to a function"
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager