部署需要把模型压缩成pb文件,但是需要一些前期准备
首先需要生成label2id.pkl 这个文件,需要添加代码在run_classifier.py或自己的运行任务文件中
#---------
import pickle
output_label2id_file = os.path.join(FLAGS.output_dir, "label2id.pkl")
if not os.path.exists(output_label2id_file):
with open(output_label2id_file, 'wb') as w:
pickle.dump(label_map, w)
#---------
def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenizer):
"""Converts a single `InputExample` into a single `InputFeatures`."""
if isinstance(example, PaddingInputExample):
return InputFeatures(
input_ids=[0] * max_seq_length,
input_mask=[0] * max_seq_length,
segment_ids=[0] * max_seq_length,
label_id=0,
is_real_example=False)
label_map = {}
for (i, label) in enumerate(label_list):
label_map[label] = i
#---------
import pickle
output_label2id_file = os.path.join(FLAGS.output_dir, "label2id.pkl")
if not os.path.exists(output_label2id_file):
with open(output_label2id_file, 'wb') as w:
pickle.dump(label_map, w)
#---------
tokens_a = tokenizer.tokenize(example.text_a)
tokens_b = None
if example.text_b:
tokens_b = tokenizer.tokenize(example.text_b)
if tokens_b:
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[0:(max_seq_length - 2)]
然后需要创建 freeze_graph.py 这个文件(代码放在最后)
然后运行 freeze_graph.py
bert_model_dir 是训练完模型的路径
model_dir 是输出pb模型的路径
注意参数和之前统一
然后会看到生成了 .pd 文件
有的时候报错是因为需要 bert_config.json 和 vocab 两个文件,这个在下载的BERT模型文件夹里面可以找到
例 chinese_L-12_H-768_A-12 里面就有,BERT官方链接 https://github.com/google-research/bert/
python free