Transformer onnx模型的导出

0. Encoder.onnx 和decoder.onnx

        Seq2seq结构也称为encoder-decoder结构,在decoder结构为单步解码时,seq2seq的导出的与只有encoder时(如BERT等)没有差别。当decoder为多步解码时,如生成式任务中的机器翻译、文本摘要等,由于decoder需要多次用到,问题变得稍稍复杂了一点。

        由于onnx对多步解码的支持较差,我们选择将模型拆分成encoder和decoder两部分。(官网上有torch.jit.script(model),但对transformers之外的框架不一定适用,所以我们讨论更一般的情况。)

        在常见的基于transformer机制的框架中,一般都是先定义各层中的计算机制,如linear,activation (relu, gelu), scaled_multi_head_attention等,再定义各层(layer, T5中的block)的计算,然后再定义encoder, decoder,最后定义整个transformer及对应的train和inferece各状态对应的输入输出。

        在定义encoder和decoder的时候,分别加入了src和tgt embedding来进行初始化,有时使用shared embedding, 即src embedding和tgt embedding相同。Encoder和decoder的最大不同在于,encoder中只涉及self-attention的计算,decoder先进行self-attention的计算,再进行cross_attention的计算。

### 将Transformer模型转换为ONNX格式 为了将Transformer模型转换成ONNX格式,可以利用`optimum.onnxruntime`库来简化这一过程。通过安装必要的软件包如`transformers`, `optimum[onnxruntime]`以及`onnx`[^1],能够方便地完成模型导出工作。 下面是一个简单的Python脚本例子展示如何实现这一点: ```python from transformers import AutoModelForSequenceClassification, AutoTokenizer from optimum.onnxruntime import ORTModelForSequenceClassification import torch model_name = "distilbert-base-uncased-finetuned-sst-2-english" tokenizer = AutoTokenizer.from_pretrained(model_name) pytorch_model = AutoModelForSequenceClassification.from_pretrained(model_name) # 导出ONNX格式 onnx_path = "./onnx/" ort_model = ORTModelForSequenceClassification.from_pretrained(pytorch_model, export=True, from_transformers=True) ort_model.save_pretrained(onnx_path) ``` 这段代码首先加载了一个预训练好的PyTorch版本的Transformers模型及其对应的分词器;接着创建了一个`ORTModelForSequenceClassification`实例并调用了其`export()`方法来进行实际的转换操作;最后保存了转换后的ONNX模型文件至指定路径下。 对于想要使用已经存在的ONNX模型进行推理的情况,则可以直接加载该模型而无需重新编译或转换: ```python from optimum.onnxruntime import ORTModelForSequenceClassification from transformers import AutoTokenizer model_name_or_path = "<path_to_onnx_directory>" tokenizer = AutoTokenizer.from_pretrained("<original_model>") ort_model = ORTModelForSequenceClassification.from_pretrained(model_name_or_path, file_name="model.onnx") inputs = tokenizer("This is a sample sentence.", return_tensors="pt") outputs = ort_model(**inputs) print(outputs.logits) ``` 在这个过程中,只需提供之前保存过的ONNX目录位置作为参数传递给`ORTModelForSequenceClassification.from_pretrained()`函数即可轻松加载模型,并像平常一样执行前向传播计算得到预测结果。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值