bert 中文 代码 谷歌_最近应用bert做ner的一点点总结

本文分享了一个将BERT模型应用于英文命名实体识别(NER)的实验过程,包括如何克服资源限制、解决技术难题,最终实现了良好的识别效果。

今天下班了,想着有好多天没更新文章了,就想着把最近做的一点实验给记录下来,实验是关于bert应用于英文命名实体识别的。认真看完,有吐槽也有搞笑。

7dbf5e183c9f396c61984f015dda9e88.png

1   背景

前段时间我在家里办公了小久,被安排去研究bert用于ner识别,心里很是咯噔了一下,心想这么大的模型,我那小破台式机才1060GTX 4G显存,我咋去跑bert这么大的模型,然而公司现在也没机器可用,要是我有块Tesla V100 32G(舔舔口水)该有多好 ! 后来想了想,那就把bert层的参数固定住,不参与训练,反正这么多大佬都说bert预训练词向量的表达已经是很有效的信息了。

2  过程

这么大的项目当然不可能从头开始写撒,于是在github上找了几个bert ner相关的项目,废了九牛二虎之力,终于把bert 层的权重给固定住了。那就开始实验吧,发现这几个项目都是各种报错,这里是要吐槽下,tf1.x确实在调试方面很难用,不过我对torch的工程化一直很嫌弃,所以一直也是在坚持使用和学习tf1.x版本。最后还是借鉴了github上的一个项目,因为就它看起来最靠谱:https://github.com/macanv/BERT-BiLSTM-CRF-NER

它是在官方公布的bert源码上进行了封装,采用的是tf estimator + model_fn回调函数进行封装训练,顺便说一下tf.estimator也是官方极力推荐的tf1.x版本的训练方法。

不过这个项目是做中文NER识别的,因为中文数据经过训练和测试效果都没啥问题,就决定在上面改造成英文NER,改造了下数据输入和最后阶段标签还原模块,因为英文中存在一个wordpiece问题。

插曲,吐槽1min

遇到了不止一个人,把github上面做文本分类或者实体识别的项目,拿过来直接替换成自己的英文数据和英文词库,然后就开始搞起来,然后一堆报错,就开始吐槽人家代码有问题,人家能开源出来一般还是写的很不错的吧,典型的缺乏科学敬畏精神,自己也不看看原来是用来处理中文的,英文和中文数据的前后处理流程能一样吗?Github上的开源项目大多是论文研究中源代码,大多还是外国佬的,一般拿过来在工程中使用多多少少都要做修改,也就是谷歌这种大公司开源出来的项目需要做的改动会比较少,这里对无脑拿来主义真是无可奈何,自己想去吧。

好了,回归正题,上面说到改完英文数据前后处理流程代码后,就开始训练,然后打开了tensorboard每500步骤默默监测损失,发现损失一直降不下去,有点小郁闷,我就纳闷不应该啊,就算我把前面的bert层给固定住了,后面还有一个全连接+CRF呢,参数也不少,也在更新,而且bert预训练模型本身对词的表达已经是相当到位了。后来检查发现,原来在CRF解码之前,代码中使用了tanh激活,这就是问题所在了,相当于是把CRF解码前的结果给限定到了【-1,1】之间了,那Verterbi解码当然要出问题,后来在issue中找到了相同疑问,原来作者忘记改了。找到了问题立马又开始训练,然后我的4G 小显存,电脑风扇呼啦啦的就转起来了,真怕烧坏了,大夏天的。好在只训练后面参数,4小时就训练完了5轮,损失也降下去了。

5fd9a97d6d2f1a0c25cc5bf0eb8a2f9d.png

最后预测了在CoNLL-2003测试集上的结果,实体识别Precision/Recall/F1

分数都达到了85%。这里总结一下,如果是小数据量不建议用bert fine tune,我这里数据不多也不少,有8000多个英文句子,只是因为我没有显存才固定了bert层参数,并不是我想固定住bert权重。为啥呢,因为数据量太少,bert层的参数随着更新可能会过拟合,丢失了泛化学习特征,小数据用 bilstm+crf效果已经很不错了。

3  部署

最后当然是要把训练好的模型拿来部署啦。这里还是首推tensorflow serving。需要在docker中完成部署,首先是要把训练好的ckpt模型转化成saved model pb,也即tf serving支持的模型格式,就是profbuffer。写了部分代码,终于生成了pb,起了tensorflow/serving容器打好了镜像,再启动

docker run -itd -p 8500:8500 -p 8501:8501 tensorflow/serving:bert_ner.

然后就是写客户端请求文件了,上面启动容器的时候暴露了2个调用端口,一个是grpc 8500 ,一个是 rest api 8501。这是两种不同的调用方式,部分调用核心代码如下:

(1)rest api 方式 

    调用很简洁,但是速度跟不上grpc

def bert_ner(sentence,tokenizer ):    def convert(line):        feature = convert_single_example(0, line, label_list, 128, tokenizer, 'p')        input_ids = np.reshape([feature.input_ids], (batch_size,128))        input_mask = np.reshape([feature.input_mask], (batch_size, 128))        segment_ids = np.reshape([feature.segment_ids], (batch_size, 128))        label_ids = np.reshape([feature.label_ids], (batch_size, 128))        return input_ids, input_mask, segment_ids, label_ids    # 处理输入    sentence_token = tokenizer.tokenize(sentence)    print('your input is:{}'.format(sentence_token))    input_ids, input_mask, segment_ids, label_ids = convert(sentence_token)    input_ids_list = input_ids.tolist()    input_mask_list = input_mask.tolist()    url = 'http://127.0.0.1:8501/v1/models/bert_ner/:predict'    data = json.dumps(            {                    "name": 'bert_ner',                    "signature_name":'result',                    "inputs":{                            'input_ids': input_ids_list,                            'input_mask': input_mask_list}})    result = requests.post(url, data=data).json()    return result

(2)grpc方式

调用略微复杂

tf.app.flags.DEFINE_string('server', 'localhost:8500', 'PredictionService host:port')FLAGS = tf.app.flags.FLAGSdef main(sentence, tokenizer):    # host, port = FLAGS.server.split(':')    # channel = implementations.insecure_channel(host, int(port))    # stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)    channel = grpc.insecure_channel(FLAGS.server)    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)    #  Send request    def convert(line):        feature = convert_single_example(0, line, label_list, 128, tokenizer, 'p')        input_ids = np.reshape([feature.input_ids], (batch_size,128))        input_mask = np.reshape([feature.input_mask], (batch_size, 128))        segment_ids = np.reshape([feature.segment_ids], (batch_size, 128))        label_ids = np.reshape([feature.label_ids], (batch_size, 128))        return input_ids, input_mask, segment_ids, label_ids    # 处理输入    sentence_token = tokenizer.tokenize(sentence)    print('your input is:{}'.format(sentence_token))    input_ids, input_mask, segment_ids, label_ids = convert(sentence_token)    # input_ids = np.expand_dims(input_ids, axis=0)    # input_mask = np.expand_dims(input_mask, axis=0)    request = predict_pb2.PredictRequest()    request.model_spec.name = 'bert_ner' # 这个name跟tensorflow_model_server --model_name="bert_ner" 对应    request.model_spec.signature_name = 'result' # 这个signature_name 跟signature_def_map 对应    request.inputs['input_ids'].CopyFrom(         tf.contrib.util.make_tensor_proto(input_ids, shape=[input_ids.shape[0], input_ids.shape[1]])) # shape跟 keras的model.input类型对应    request.inputs['input_mask'].CopyFrom(         tf.contrib.util.make_tensor_proto(input_mask, shape=[input_mask.shape[0], input_mask.shape[1]]))    result_future = stub.Predict(request, 10.0) # 10 secs timeout    # print("result_future",result_future)    response1 = np.array(result_future.outputs['pred_label'].int_val)    print("label_outcome: ", response1)    print(len(response1))

6d48206e19c03790f0c5d50eec9f95cc.png

精选资源
BERT_NER
02-21
基于BERTNER模型,集成了LSTM,CRF网络结构,FGM,EMA等提分trick,还有单独的MRC框架LSTM (Long Short-Term Memory) 是一种特殊的循环神经网络(RNN)架构,用于处理具有长期依赖关系的序列数据。传统的RNN在处理长序列时往往会遇到梯度消失或梯度爆炸的问题,导致无法有效地捕捉长期依赖。LSTM通过引入门控机制(Gating Mechanism)和记忆单元(Memory Cell)来克服这些问题。 以下是LSTM的基本结构和主要组件: 记忆单元(Memory Cell):记忆单元是LSTM的核心,用于存储长期信息。它像一个传送带一样,在整个链上运行,只有一些小的线性交互。信息很容易地在其上保持不变。 输入门(Input Gate):输入门决定了哪些新的信息会被加入到记忆单元中。它由当前时刻的输入和上一时刻的隐藏状态共同决定。 遗忘门(Forget Gate):遗忘门决定了哪些信息会从记忆单元中被丢弃或遗忘。它也由当前时刻的输入和上一时刻的隐藏状态共同决定。 输出门(Output Gate):输出门决定了哪些信息会从记忆单元中输出到当前时刻的隐藏状态中。同样地,它也由当前时刻的输入和上一时刻的隐藏状态共同决定。 LSTM的计算过程可以大致描述为: 通过遗忘门决定从记忆单元中丢弃哪些信息。 通过输入门决定哪些新的信息会被加入到记忆单元中。 更新记忆单元的状态。 通过输出门决定哪些信息会从记忆单元中输出到当前时刻的隐藏状态中。 由于LSTM能够有效地处理长期依赖关系,它在许多序列建模任务中都取得了很好的效果,如语音识别、文本生成、机器翻译、时序预测等。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值