本文手把手讲解如何将微调后的 GPT-2 模型导出为 ONNX 格式,解决常见导出错误,封装自定义 Wrapper 支持动态 shape,铺平后续 TensorRT 加速与跨平台部署之路。
📦 为什么要导出 ONNX?
ONNX(Open Neural Network Exchange)是一种通用模型交换格式,可以将 PyTorch、TensorFlow 等训练好的模型导出后,在:
- ✅ TensorRT 上做高性能推理
- ✅ ONNX Runtime 实现跨平台部署
- ✅ WebAssembly、移动端等环境中部署
1️⃣ 环境准备
确保你已具备以下条件:
- 已完成 GPT-2 微调(如使用 Hugging Face)
- 安装了
transformers
、torch
、onnx
导出前请准备好本地训练模型目录,如:
../python1_basic_training/gpt2_finetune/
├── config.json
├── model.safetensors
├── tokenizer_config.json
├── vocab.json
├── merges.txt
2️⃣ 加载模型与分词器
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
model_path = "../python1_basic_training/gpt2_finetune"
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path)
model.eval()
注意:GPT-2 默认启用了 use_cache=True
,用于快速生成,但 ONNX 不支持这种动态缓存结构,必须关闭。
3️⃣ 封装 Wrapper 类,禁用 use_cache
class Wrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.config = model.config
self.config.use_cache = False # 显式关闭缓存
def forward(self, input_ids, attention_mask):
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)
return outputs.logits
✅
Wrapper
类的作用:
- 屏蔽 GPT-2 原始 forward 中的
past_key_values
- 明确指定导出输出为
logits
4️⃣ 构造 dummy 输入
text = "Hello world"
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
- ONNX 是静态图 → 必须提供示例输入,告知模型输入维度与结构
- 本例中输入 shape 为
(1, N)
,支持后续导出为 动态长度模型
5️⃣ 导出为 ONNX 文件
wrapped_model = Wrapper(model)
onnx_path = "model/gpt2.onnx"
torch.onnx.export(
wrapped_model,
(input_ids, attention_mask),
onnx_path,
input_names=["input_ids", "attention_mask"],
output_names=["logits"],
dynamic_axes={
"input_ids": {0: "batch", 1: "sequence"},
"attention_mask": {0: "batch", 1: "sequence"},
"logits": {0: "batch", 1: "sequence"}
},
opset_version=14 # ✅ 建议使用 >=13 的版本,支持 FlashAttention 等操作
)
导出成功后终端将输出:
✅ 模型成功导出为 ONNX:model/gpt2.onnx
🧪 验证导出模型结构
你可以使用 netron
查看结构是否正确:
pip install netron
netron model/gpt2.onnx
确认 input_ids
、attention_mask
、logits
的动态 shape 是否正确映射。
❗ 常见报错与排查建议
报错信息 | 原因与解决方式 |
---|---|
RuntimeError: export with use_cache=True not supported | ✅ 使用 Wrapper 禁用 use_cache |
ONNX export failed: Can't export function 'forward' | 检查模型是否有自定义参数未处理 |
opset_version too low | 使用 opset_version=14 或更高 |
📌 总结
- ONNX 是后续 TensorRT / ONNXRuntime 加速部署的必要前提
- GPT-2 导出必须封装 Wrapper 禁用
use_cache
- 设置
dynamic_axes
支持动态 batch / sequence 推理 - 对应项目 gpt2-trt-deploy 的 export_to_onnx.py
- 下一篇将继续介绍如何基于此模型进行 TensorRT 加载与推理!
📎 本文为 GPT-2 项目加速部署系列第一篇
🧭 本系列 GPT-2 项目加速部署系列五部曲
- 🧩 第一篇:将 GPT-2 导出为 ONNX 模型:部署加速第一步
- 🚀 第二篇:用 TensorRT 加速 GPT-2 推理:ONNX 加载、CUDA 显存管理与性能优化实战
- 🌐 第三篇:PyTorch vs TensorRT 推理性能对比:GPT-2 加速效果实测报告
- 🧠 第四篇:用 Flask 封装 GPT-2 TensorRT 推理接口:构建可远程调用的文本生成服务
- 💼 第五篇:用 Golang 构建 GPT-2 前端服务:对接 Flask API 实现跨语言调用
📌 YoanAILab 技术导航页
💡 项目源码 × 实战部署 × 转型经验,一页总览
👉 点击查看完整导航页
📚 包含内容:
- 🧠 GPT-2 项目源码(GitHub)
- ✍️ CSDN 技术专栏合集
- 💼 知乎转型日志
- 📖 公众号 YoanAILab 全文合集