将PyTorch模型转换成TFLite格式教程

一、基础环境搭建(全内网)

1. 安装 Miniconda(国内镜像)
 

bash

复制

# 使用中科大镜像下载(若失效可替换为阿里云/清华源)
wget http://mirrors.ustc.edu.cn/anaconda/miniconda/Miniconda3-py39_24.5.0-0-Linux-x86_64.sh

# 安装至默认路径(全程回车确认)
bash Miniconda3-py39_24.5.0-0-Linux-x86_64.sh
source ~/.bashrc  # 刷新环境变量
2. 配置国内镜像源
 

bash

复制

# Conda 清华镜像
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/

# Pip 阿里云镜像
mkdir -p ~/.pip
echo -e "[global]\nindex-url = https://mirrors.aliyun.com/pypi/simple/\ntrusted-host = mirrors.aliyun.com" > ~/.pip/pip.conf

二、创建专用环境(TensorFlow 2.10.0)

1. 创建并激活环境
 

bash

复制

conda create -n tf2.10 python=3.9 -y
conda activate tf2.10
2. 安装核心依赖(版本锁定)
 

bash

复制

# 安装 TensorFlow 2.10 及工具链
pip install tensorflow==2.10.0 \
    onnx==1.13.0 \
    onnxruntime==1.13.1 \
    onnx-tf==1.10.0 \
    torch==1.12.1+cpu -f https://download.pytorch.org/whl/torch_stable.html

如果PyTorch不能通过上述命令安装或者安装进度缓慢建议使用conda安装(推荐)

bash

# 添加PyTorch官方Conda源 
conda config --add channels pytorch 
conda config --append channels pytorch-test
 # 安装CPU版本 conda install -y pytorch==2.0.1 torchvision==0.15.2 cpuonly

三、模型转换全流程(PyTorch → TFLite)

1. 创建项目结构
 

bash

复制

mkdir -p ~/model_conversion/{models,scripts}
# 将预训练模型文件 (.pth) 放入 ~/model_conversion/models
2. 转换脚本 ~/model_conversion/scripts/convert.py
 

python

复制

import torch
import torch.nn as nn
import onnx
from onnx_tf.backend import prepare
import tensorflow as tf

# 通用生成器架构(适配多数GAN模型)
class BaseGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),  # 输入通道3 (RGB)
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, 3, padding=1),   # 输出通道3 (RGB)
            nn.Tanh()                         # 输出范围[-1,1]
        )

    def forward(self, x):
        return self.main(x)

# 加载预训练权重
model = BaseGenerator()
model.load_state_dict(torch.load('../models/generator.pth', map_location='cpu'))
model.eval()

# 导出ONNX(固定输入尺寸256x256)
dummy_input = torch.randn(1, 3, 256, 256)
torch.onnx.export(
    model, 
    dummy_input, 
    "../models/generator.onnx",
    input_names=['input'],
    output_names=['output'],
    opset_version=14  # 与TF2.10兼容的算子版本
)

# ONNX转TensorFlow
onnx_model = onnx.load("../models/generator.onnx")
tf_rep = prepare(onnx_model)
tf_rep.export_graph("../models/generator_tf")

# 生成TFLite模型(启用优化)
converter = tf.lite.TFLiteConverter.from_saved_model("../models/generator_tf")
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # 默认优化
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # 原生算子
    tf.lite.OpsSet.SELECT_TF_OPS      # TensorFlow算子
]
tflite_model = converter.convert()

with open('../models/generator.tflite', 'wb') as f:
    f.write(tflite_model)

print("✅ 转换完成!模型路径:~/model_conversion/models/generator.tflite")
3. 执行转换
 

bash

复制

cd ~/model_conversion/scripts
python convert.py

四、本地验证测试

1. 测试脚本 test.py
 

python

复制

import cv2
import numpy as np
import tensorflow as tf

def preprocess(img_path):
    # 预处理(与训练一致)
    img = cv2.imread(img_path)
    img = cv2.resize(img, (256, 256))          # 固定尺寸
    img = img.astype(np.float32) / 127.5 - 1  # 归一化到[-1,1]
    return np.expand_dims(img, axis=0)        # 添加batch维度

# 加载TFLite模型
interpreter = tf.lite.Interpreter(
    model_path='../models/generator.tflite',
    num_threads=4  # 多线程加速
)
interpreter.allocate_tensors()

# 准备输入
input_data = preprocess("test.jpg")
input_idx = interpreter.get_input_details()[0]['index']
interpreter.set_tensor(input_idx, input_data)

# 执行推理
interpreter.invoke()

# 获取输出并保存
output = interpreter.get_tensor(
    interpreter.get_output_details()[0]['index']
)
output_img = (output[0] + 1) * 127.5
cv2.imwrite("../output/result.jpg", output_img)
2. 运行测试
 

bash

复制

python test.py

五、错误解决方案

❌ ​​错误:Unsupported ONNX opset version: 14
 

python

复制

# 修改转换脚本中的 opset 版本
torch.onnx.export(..., opset_version=13)  # 降低版本
❌ ​​错误:Input shape mismatch
 

python

复制

# 在转换脚本中固定输入尺寸
dummy_input = torch.randn(1, 3, 256, 256)  # 与模型实际输入一致

六、系统级依赖补全

 

bash

复制

# Ubuntu/Debian
sudo apt-get install -y \
    libgl1-mesa-glx \    # OpenCV硬件加速
    libsm6 \              # GUI支持
    libxext6

# CentOS/RHEL
sudo yum install -y \
    mesa-libGL \
    libXext \
    libSM

关键验证点

  1. ​环境验证​​:
     

    bash

    复制

    python -c "import tensorflow as tf; print(tf.__version__)"  # 应输出 2.10.0
  2. ​模型输入/输出尺寸​​:
     

    bash

    复制

    tflite_model = interpreter._model  # 通过 Netron 可视化检查
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值