深度学习————模型保存与部署

第一部分:模型保存基础

什么是模型保存?

当你训练好一个深度学习模型后,它会拥有“学习到的参数”,这些参数(权重、偏置等)构成了模型的“知识”。如果不保存这些参数,那么训练好的模型在关闭程序后就会丢失。

所以,模型保存就是将训练好的参数(或整个模型)保存到磁盘上,供之后加载使用或部署


 两种主要保存方式(以 PyTorch 为例)

  1. 保存模型参数(推荐)

    • 只保存模型的状态字典(state_dict),这是最推荐的方法。

    • 更轻便,适合部署、版本管理。

    • 加载时需要重新构建模型结构,然后加载参数。

  2. 保存整个模型结构与参数(不推荐)

    • 使用 torch.save(model) 直接保存整个模型对象。

    • 不可跨 Python 版本或环境,不利于调试与迁移。


 常见保存格式

框架推荐保存格式说明
PyTorch.pth / .pt.pth 一般用于 state_dict
TensorFlow.ckpt / .pb / SavedModelSavedModel 用于部署
ONNX.onnx便于跨框架、跨平台部署


保存路径与命名建议

  • 路径统一、版本可控,如:

    checkpoints/
    ├── model_v1_2025-05-01.pth
    ├── model_best_val.pth
    └── model_latest.pth
    
  • 可使用时间戳 + 性能指标命名,便于后续追踪:

    model_acc87.4_epoch15.pth
    

 版本管理建议

  • 使用日志系统(如 TensorBoard、WandB)记录对应版本表现。

  • 每次训练完成后保存多个模型:如最优验证集模型(best)、最后模型(last)。

  • 大项目建议结合 Git 和 DVC(Data Version Control)管理模型文件。

第二部分:PyTorch 中的模型保存与加载实战

PyTorch 提供了非常灵活和强大的模型保存与加载机制,主要通过 state_dict(模型参数字典)进行操作。下面我们详细讲解每一步并提供示例代码。


 一、什么是 state_dict

state_dict 是一个 Python 字典,保存了模型中每一层的参数(权重和偏置等)。它的格式大致如下:

{
  'layer1.weight': tensor(...),
  'layer1.bias': tensor(...),
  ...
}

每个模块(如 nn.Linear, nn.Conv2d)都将其参数注册在 state_dict 中。


🔹 二、保存模型参数(推荐)

保存代码:

import torch

# 假设你有一个模型实例 model
torch.save(model.state_dict(), 'model.pth')

注意事项:

  • model.pth 只是文件名,扩展名可以是 .pt.pth,没有区别。

  • 只保存参数,不包含模型结构,因此加载时需要手动定义结构。


 三、加载模型参数

加载步骤分两步走:

  1. 重新定义模型结构;

  2. 加载参数到模型中。

    # 1. 定义模型结构(必须与保存时一致)
    model = MyModel()
    
    # 2. 加载参数
    model.load_state_dict(torch.load('model.pth'))
    
    # 3. 切换到评估模式(部署时必须)
    model.eval()
    

🔹 四、保存整个模型(不推荐)

torch.save(model, 'entire_model.pth')

然后加载:

model = torch.load('entire_model.pth')
model.eval()

缺点:

  • 依赖于模型的类定义和 Python 环境;

  • 一旦结构变动,加载可能出错;

  • 不适合跨平台部署。


五、训练状态一起保存(含优化器)

训练中断后可恢复继续训练,需要同时保存模型和优化器状态。

# 保存
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, 'checkpoint.pth')

加载时:

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

 六、保存和加载到指定设备(如 GPU)

# 保存时无关设备
torch.save(model.state_dict(), 'model.pth')

# 加载到 CPU
model.load_state_dict(torch.load('model.pth', map_location='cpu'))

# 加载到 GPU
device = torch.device('cuda')
model.load_state_dict(torch.load('model.pth', map_location=device))

七、完整示例(含模型结构)

import torch
import torch.nn as nn

# 模型定义
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# 初始化模型与优化器
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())

# 保存
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'model_checkpoint.pth')

# 加载
model2 = MyModel()
optimizer2 = torch.optim.Adam(model2.parameters())
checkpoint = torch.load('model_checkpoint.pth')
model2.load_state_dict(checkpoint['model_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])

 

第三部分:模型部署基础概念

模型训练完成后,并不是终点。部署模型的目的,是将它放到现实世界中,为用户或系统提供服务。比如:

  • 智能客服系统:用户发送一条消息,模型给出回复;

  • 医疗图像诊断:上传CT图像,模型输出预测结果;

  • 教学系统:上传作业照片,模型识别题目并自动评分。

本部分将介绍部署的核心概念及常见方式。


 一、为什么需要部署?

目标说明
模型服务化把训练好的模型变成一个可以实时调用的服务(如 API)
多用户访问支持多个用户、多个终端访问(Web、App等)
实时推理对输入进行实时预测,如语音识别、图像识别
系统集成将模型集成进现有的软件系统、产品或平台
规模扩展支持大规模并发调用,进行推理加速、负载均衡等

 二、部署分类与对比

我们按部署场景将常见方式进行分类总结:

部署方式描述适合场景优点缺点
本地部署模型运行在本地电脑或服务器上开发测试、小项目简单易操作,无需网络不易扩展,不适合多人使用
Web 服务部署封装成 HTTP API / Web UI实际产品,后台系统可远程访问,适合用户使用部署较复杂,对安全性有要求
云端部署部署到云服务器(如阿里云、AWS)大型项目、商业部署可弹性伸缩,服务稳定成本高,涉及 DevOps 知识
移动端部署模型打包到手机或嵌入式设备移动AI、边缘设备离线可用,低延迟受限于算力、平台兼容性
服务器集群部署结合容器与负载均衡器部署多个模型高并发、高可用场景可自动扩缩、容错性好依赖 Docker/K8s,配置复杂

 三、部署方式常用工具和框架

场景工具/平台示例简述
本地部署Flask、Gradio、Streamlit简单封装模型为 API 或 Web 界面
Web 后端部署FastAPI、Flask + Gunicorn可高性能提供 REST 接口
云服务部署HuggingFace Spaces、阿里云 ECS快速上线,适合演示和产品原型
模型导出与推理加速TorchScript、ONNX、TensorRT优化模型结构,提高推理速度
多模型管理MLflow、TorchServe、NVIDIA Triton模型托管、版本管理与部署平台

 四、常见部署架构图示意(文字版)

用户 -> 浏览器 / App
      |
      V
   [ Web 前端 ]        ←(Gradio / React + Flask 等)
      |
      V
   [ Web 后端 API ]
      |
      V
   [ 推理服务(模型加载) ]
      |
      V
   [ 模型参数 / 权重文件 ]

 五、从训练到部署流程总览

  1. 训练模型:在本地或服务器完成训练;

  2. 保存模型:保存为 .pth.onnx 文件;

  3. 封装接口:使用 Flask / Gradio / FastAPI 编写服务;

  4. 构建前端(可选):使用 HTML / React / Gradio 交互;

  5. 部署上线:本地测试通过后部署到服务器或平台;

  6. 用户使用:通过网页、App 等方式访问部署的服务。

第四部分:模型导出为部署格式(TorchScript 和 ONNX)

训练好的 PyTorch 模型需要导出成标准格式,才能跨平台、跨框架、高效地部署。TorchScriptONNX 是 PyTorch 中最常用的导出格式。

本部分将详细讲解两者的概念、区别、导出方式及使用场景。


 一、为什么要导出模型?

虽然 .pth 格式在 PyTorch 内部很方便使用,但部署时常常需要:

  • 加快推理速度

  • 在没有 Python 的环境中运行

  • 与其他框架(如 TensorFlow、C++、移动端)兼容

  • 更稳定、更可控的模型格式

这时就需要导出为中间格式,如 TorchScript 或 ONNX。


 TorchScript 模型导出

 什么是 TorchScript?

TorchScript 是 PyTorch 的一个中间表示,它允许模型以静态图的形式保存并运行。这使得:

  • 可脱离 Python 环境运行

  • 可通过 C++ API 部署

  • 支持推理优化(如 torch.jit.optimize_for_inference


 TorchScript 两种转换方式

1. 追踪法(Tracing)

适合无条件分支的模型。

import torch

# 假设 model 是你训练好的模型
model.eval()

example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)

# 保存
traced_model.save("model_traced.pt")
2. 脚本法(Scripting)

适合包含 if/else、循环等逻辑的模型。

scripted_model = torch.jit.script(model)
scripted_model.save("model_scripted.pt")

💡 TorchScript 加载与推理

import torch

model = torch.jit.load("model_traced.pt")
model.eval()

output = model(torch.randn(1, 3, 224, 224))

ONNX 模型导出

什么是 ONNX?

ONNX(Open Neural Network Exchange)是一种通用模型格式,由微软和 Facebook 发起,支持多种深度学习框架,如:

  • PyTorch

  • TensorFlow

  • MXNet

  • OpenCV DNN

  • ONNX Runtime

  • TensorRT


PyTorch 转 ONNX 示例

import torch

model.eval()
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model, 
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    opset_version=11
)

ONNX 模型验证

你可以用 onnx 包验证导出是否成功:

import onnx

onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)  # 抛出异常说明有问题

 推理:ONNX Runtime

import onnxruntime as ort
import numpy as np

session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

output = session.run(None, {input_name: input_data})

 TorchScript vs ONNX 对比总结

特性TorchScriptONNX
支持框架PyTorch 本身 + C++多框架(TensorRT、ONNX RT等)
性能优化支持是(官方提供优化接口)是(ONNX Runtime / TensorRT)
转换复杂度简单稍复杂,需要注意版本/OP集
支持 Python 控制流否(静态图模型)
移植性中(依赖 PyTorch 环境)强(适合工业部署)
推荐场景内部 PyTorch 部署跨平台、商业部署

第五部分:模型部署方式详解(Gradio、Flask、ONNX Runtime等)

在本部分,我们将从实用角度出发,逐一讲解几种常用部署方式,并配合完整代码模板,帮助你快速上手部署一个推理服务。


 方式一:使用 Gradio 快速构建 Web 界面

Gradio 是一个非常流行的 Python 库,用于快速构建交互式 Web 应用,适合演示、测试和初步上线。


 1. 安装 Gradio
pip install gradio

2. 代码示例:图像分类模型部署(TorchScript)
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image

# 加载 TorchScript 模型
model = torch.jit.load("model_traced.pt")
model.eval()

# 图像预处理函数
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# 推理函数
def predict(img):
    img = preprocess(img).unsqueeze(0)
    with torch.no_grad():
        output = model(img)
        probs = torch.nn.functional.softmax(output[0], dim=0)
    return {f"Class {i}": float(p) for i, p in enumerate(probs)}

# 创建界面
iface = gr.Interface(fn=predict, inputs="image", outputs="label")
iface.launch()

 启动后会自动打开浏览器访问地址,如:http://127.0.0.1:7860


方式二:使用 Flask 构建 RESTful 接口(API)

Flask 是 Python 中常用的 Web 框架,可以把模型封装成一个 HTTP 接口供前端或其他服务调用。


 1. 安装 Flask
pip install flask

 2. API 接口部署模板(适合 ONNX)
from flask import Flask, request, jsonify
import onnxruntime as ort
import numpy as np

app = Flask(__name__)

# 初始化 ONNX 推理器
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name

@app.route("/predict", methods=["POST"])
def predict():
    data = request.json["input"]  # 输入是嵌套 list
    input_array = np.array(data).astype(np.float32)
    output = session.run(None, {input_name: input_array})
    return jsonify({"output": output[0].tolist()})

if __name__ == '__main__':
    app.run(debug=True, port=5000)

前端可以通过 POST 请求向 /predict 发送数据,返回 JSON 格式的模型输出。


 方式三:部署到 HuggingFace Spaces(在线部署平台)

HuggingFace 提供免费的部署平台,支持 Gradio/Streamlit 应用的在线托管。


步骤:
  1. 在 https://huggingface.co/spaces 创建一个新的 Space;

  2. 选择 Gradio 模板;

  3. 上传你的代码文件(如 app.py)和 requirements.txt

  4. 提交后等待构建,即可访问。

示例

gradio
torch
torchvision

 方式四:ONNX Runtime + FastAPI + Docker(工业部署)

适合构建高性能、可扩展的 API 服务。

  • 使用 FastAPI 替代 Flask(性能更高);

  • 使用 Docker 打包(环境一致性);

  • 使用 ONNX Runtime(加速推理);

 若你感兴趣,我可以提供该方式的完整项目结构与部署脚本。


 常见部署注意事项

问题/注意点说明
模型文件太大可用 torch.quantization 压缩模型
GPU/CPU 版本不一致部署前明确目标环境是否支持 CUDA
接口响应慢FastAPI + Uvicorn 替代 Flask
高并发请求处理困难使用 Gunicorn 或 Docker + Kubernetes
数据预处理慢把预处理逻辑也放在服务端完成
服务崩溃/异常退出加入异常处理与日志记录系统

 

第六部分:高级部署与优化技巧(模型压缩、推理加速、Docker 打包、前端集成)

当你完成了模型部署的基本流程,进一步优化部署效果(速度、稳定性、易用性)就很关键了。下面我们从多个方面介绍进阶技巧。


 一、模型压缩与推理加速

部署模型时,常常遇到模型太大、推理太慢、占用资源高等问题。可以通过以下几种方式进行模型压缩推理加速


1. 模型量化(Quantization)

将浮点数权重压缩成更小的数据类型(如 float16int8),大幅降低模型大小和推理耗时。

静态量化(Post-training)示例:

import torch.quantization

model_fp32 = ...  # 已训练模型
model_fp32.eval()

# 准备量化配置
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model_fp32, inplace=True)

# 运行一次推理(用于收集统计信息)
_ = model_fp32(torch.randn(1, 3, 224, 224))

# 转换为量化模型
quantized_model = torch.quantization.convert(model_fp32, inplace=False)

# 保存
torch.jit.script(quantized_model).save("model_quant.pt")

 注意:部分模块(如 BatchNorm)不支持直接量化,需使用 QuantStub 包装。


2. 使用 Torch-TensorRT(NVIDIA GPU 加速)

Torch-TensorRT 是 NVIDIA 供的一个库,可将 TorchScript 模型转换为 TensorRT 引擎

pip install torch-tensorrt -U

简单使用:

import torch_tensorrt

trt_model = torch_tensorrt.compile(model, inputs=[torch.randn(1, 3, 224, 224).to("cuda")], enabled_precisions={torch.float16})

✅ 二、Docker 化部署(推荐生产环境使用)

Docker 可以把你的服务打包成镜像,确保环境一致性、可移植性。


1. 创建项目目录结构
deploy_app/
├── app.py               # Flask / Gradio 应用
├── model.onnx           # 导出的模型
├── requirements.txt     # 所需 Python 包
└── Dockerfile           # Docker 构建脚本

2. Dockerfile 示例
FROM python:3.10

WORKDIR /app

COPY requirements.txt .
RUN pip install -r requirements.txt

COPY . .

CMD ["python", "app.py"]

3. 构建并运行容器
docker build -t my_model_app .
docker run -p 5000:5000 my_model_app

 若部署到云端(如阿里云、AWS),推荐结合 Nginx 反向代理与容器编排(如 docker-compose 或 Kubernetes)。


 三、前端集成与美化建议

技术优点示例用途
Gradio快速搭建交互界面原型演示、测试用
Streamlit数据可视化友好图像/表格/图表展示等
HTML + JS适合自定义界面、美化展示嵌入 Web 系统、企业平台
React/Vue高度定制、适合商用产品构建完整 Web 应用

 四、完整部署案例:PyTorch → ONNX → Gradio → Docker → HuggingFace Spaces


总结与建议

部分内容概览
第一部分模型保存格式(权重、结构、完整模型)
第二部分加载与恢复模型的多种方式
第三部分部署的基本概念与分类
第四部分模型导出为 TorchScript / ONNX
第五部分使用 Gradio / Flask / ONNX Runtime 部署
第六部分模型压缩、推理加速、Docker 化、高级部署建议

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值