将 GPT-2 导出为 ONNX 模型:部署加速第一步

本文手把手讲解如何将微调后的 GPT-2 模型导出为 ONNX 格式,解决常见导出错误,封装自定义 Wrapper 支持动态 shape,铺平后续 TensorRT 加速与跨平台部署之路。


📦 为什么要导出 ONNX?

ONNX(Open Neural Network Exchange)是一种通用模型交换格式,可以将 PyTorch、TensorFlow 等训练好的模型导出后,在:

  • ✅ TensorRT 上做高性能推理
  • ✅ ONNX Runtime 实现跨平台部署
  • ✅ WebAssembly、移动端等环境中部署

1️⃣ 环境准备

确保你已具备以下条件:

  • 已完成 GPT-2 微调(如使用 Hugging Face)
  • 安装了 transformerstorchonnx

导出前请准备好本地训练模型目录,如:

../python1_basic_training/gpt2_finetune/
├── config.json
├── model.safetensors
├── tokenizer_config.json
├── vocab.json
├── merges.txt

2️⃣ 加载模型与分词器

from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

model_path = "../python1_basic_training/gpt2_finetune"
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path)
model.eval()

注意:GPT-2 默认启用了 use_cache=True,用于快速生成,但 ONNX 不支持这种动态缓存结构,必须关闭


3️⃣ 封装 Wrapper 类,禁用 use_cache

class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.config = model.config
        self.config.use_cache = False  # 显式关闭缓存

    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)
        return outputs.logits

Wrapper 类的作用:

  • 屏蔽 GPT-2 原始 forward 中的 past_key_values
  • 明确指定导出输出为 logits

4️⃣ 构造 dummy 输入

text = "Hello world"
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
  • ONNX 是静态图 → 必须提供示例输入,告知模型输入维度与结构
  • 本例中输入 shape 为 (1, N),支持后续导出为 动态长度模型

5️⃣ 导出为 ONNX 文件

wrapped_model = Wrapper(model)
onnx_path = "model/gpt2.onnx"

torch.onnx.export(
    wrapped_model,
    (input_ids, attention_mask),
    onnx_path,
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "logits": {0: "batch", 1: "sequence"}
    },
    opset_version=14  # ✅ 建议使用 >=13 的版本,支持 FlashAttention 等操作
)

导出成功后终端将输出:

✅ 模型成功导出为 ONNX:model/gpt2.onnx

🧪 验证导出模型结构

你可以使用 netron 查看结构是否正确:

pip install netron
netron model/gpt2.onnx

确认 input_idsattention_masklogits 的动态 shape 是否正确映射。


❗ 常见报错与排查建议

报错信息原因与解决方式
RuntimeError: export with use_cache=True not supported✅ 使用 Wrapper 禁用 use_cache
ONNX export failed: Can't export function 'forward'检查模型是否有自定义参数未处理
opset_version too low使用 opset_version=14 或更高

📌 总结

  • ONNX 是后续 TensorRT / ONNXRuntime 加速部署的必要前提
  • GPT-2 导出必须封装 Wrapper 禁用 use_cache
  • 设置 dynamic_axes 支持动态 batch / sequence 推理
  • 对应项目 gpt2-trt-deploy 的 export_to_onnx.py
  • 下一篇将继续介绍如何基于此模型进行 TensorRT 加载与推理!

📎 本文为 GPT-2 项目加速部署系列第一篇


🧭 本系列 GPT-2 项目加速部署系列五部曲


📌 YoanAILab 技术导航页

💡 项目源码 × 实战部署 × 转型经验,一页总览
👉 点击查看完整导航页

📚 包含内容:

  • 🧠 GPT-2 项目源码(GitHub)
  • ✍️ CSDN 技术专栏合集
  • 💼 知乎转型日志
  • 📖 公众号 YoanAILab 全文合集
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Yoan AI Lab

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值