让大家久等了,BERT推理加速终于开源了

前几个月一直有不少小伙伴问我要「LightSeq的BERT推理加速代码」,当时内部已经使用了,但是一直没空整理开源。

现在代码终于整理好了,写了一个简单的样例,大家有需要的可以使用起来了。

实现原理

这里我直接使用预训练好的BERT模型,用户只需要输入一个带有[MASK]标记的句子,就可以自动预测出完整的句子。

例如我输入“巴黎是[MASK]国的首都”,那么模型就会输出“巴黎是法国的首都。”。

LightSeq已经「完美支持了BERT模型的快速推理」,代码近期已经开源:

GitHub - bytedance/lightseq: LightSeq: A High Performance Library for Sequence Processing and Generation

BERT推理使用样例可以参考examples/inference/python目录下的ls_bert.py文件。我们用LightSeq来加速BERT推理试试。

首先需要安装LightSeq和Hugging Face:

pip install lightseq transformers

然后需要将Hugging Face的BERT模型导出为LightSeq支持的HDF5模型格式,运行examples/inference/python目录下的hf_bert_export.py文件即可,运行前将代码的第167-168两行修改为下面这样,指定使用中文版本的BERT预训练模型。

output_lightseq_model_name = "lightseq-bert-base-chinese"
input_huggingface_bert_model = "bert-base-chinese"

然后就会在运行目录下生成一个lightseq-bert-base-chinese.hdf5模型文件,导出就成功啦。

最后使用LightSeq进行推理即可:

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import lightseq.inference as lsi

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
hf_model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese")
hf_model.to("cuda:0")
ls_model = lsi.Bert("lightseq-bert-base-chinese.hdf5", 128)

while True:
    raw_text = input("请输入中文句子,要预测的字符用#代替:\n> ")
    input_text = raw_text.replace("#", "[MASK]")
    inputs = tokenizer(input_text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    mask = inputs["attention_mask"]

    outputs = ls_model.infer(input_ids, mask)
    logits = hf_model.cls(torch.Tensor(outputs).to(dtype=torch.float, device="cuda:0"))
    output_ids = logits.argmax(axis=2)
    res_text = tokenizer.batch_decode(output_ids)

    res_text = res_text[0][1:-1].replace(" ", "")
    output_text = list(raw_text)
    for i in range(len(raw_text)):
        if raw_text[i] == "#":
            output_text[i] = res_text[i]
    print("> " + "".join(output_text))

效果演示

给大家看看效果,运行我写好的代码,我们来看看会输出什么结果:

请输入中文句子,要预测的字符用#代替:
> 巴黎是#国的首都。
> 巴黎是法国的首都。

代码地址

GitHub - bytedance/lightseq: LightSeq: A High Performance Library for Sequence Processing and Generation

就在上周,首位外部贡献者出现了,修复了LightSeq的词嵌入表示的bug。

v2-867377af0406e198199342d256303c70_b.jpg
在这里我们非常欢迎感兴趣的同学来贡献自己的代码,包括但不局限于:修复bug、提供训练和推理样例、支持更多模型结构。
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

算法码上来

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值