一、基础环境搭建(全内网)
1. 安装 Miniconda(国内镜像)
bash
复制
# 使用中科大镜像下载(若失效可替换为阿里云/清华源)
wget http://mirrors.ustc.edu.cn/anaconda/miniconda/Miniconda3-py39_24.5.0-0-Linux-x86_64.sh
# 安装至默认路径(全程回车确认)
bash Miniconda3-py39_24.5.0-0-Linux-x86_64.sh
source ~/.bashrc # 刷新环境变量
2. 配置国内镜像源
bash
复制
# Conda 清华镜像
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
# Pip 阿里云镜像
mkdir -p ~/.pip
echo -e "[global]\nindex-url = https://mirrors.aliyun.com/pypi/simple/\ntrusted-host = mirrors.aliyun.com" > ~/.pip/pip.conf
二、创建专用环境(TensorFlow 2.10.0)
1. 创建并激活环境
bash
复制
conda create -n tf2.10 python=3.9 -y
conda activate tf2.10
2. 安装核心依赖(版本锁定)
bash
复制
# 安装 TensorFlow 2.10 及工具链
pip install tensorflow==2.10.0 \
onnx==1.13.0 \
onnxruntime==1.13.1 \
onnx-tf==1.10.0 \
torch==1.12.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
如果PyTorch不能通过上述命令安装或者安装进度缓慢建议使用conda安装(推荐)
bash
# 添加PyTorch官方Conda源
conda config --add channels pytorch
conda config --append channels pytorch-test
# 安装CPU版本 conda install -y pytorch==2.0.1 torchvision==0.15.2 cpuonly
三、模型转换全流程(PyTorch → TFLite)
1. 创建项目结构
bash
复制
mkdir -p ~/model_conversion/{models,scripts}
# 将预训练模型文件 (.pth) 放入 ~/model_conversion/models
2. 转换脚本 ~/model_conversion/scripts/convert.py
python
复制
import torch
import torch.nn as nn
import onnx
from onnx_tf.backend import prepare
import tensorflow as tf
# 通用生成器架构(适配多数GAN模型)
class BaseGenerator(nn.Module):
def __init__(self):
super().__init__()
self.main = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1), # 输入通道3 (RGB)
nn.ReLU(inplace=True),
nn.Conv2d(64, 3, 3, padding=1), # 输出通道3 (RGB)
nn.Tanh() # 输出范围[-1,1]
)
def forward(self, x):
return self.main(x)
# 加载预训练权重
model = BaseGenerator()
model.load_state_dict(torch.load('../models/generator.pth', map_location='cpu'))
model.eval()
# 导出ONNX(固定输入尺寸256x256)
dummy_input = torch.randn(1, 3, 256, 256)
torch.onnx.export(
model,
dummy_input,
"../models/generator.onnx",
input_names=['input'],
output_names=['output'],
opset_version=14 # 与TF2.10兼容的算子版本
)
# ONNX转TensorFlow
onnx_model = onnx.load("../models/generator.onnx")
tf_rep = prepare(onnx_model)
tf_rep.export_graph("../models/generator_tf")
# 生成TFLite模型(启用优化)
converter = tf.lite.TFLiteConverter.from_saved_model("../models/generator_tf")
converter.optimizations = [tf.lite.Optimize.DEFAULT] # 默认优化
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, # 原生算子
tf.lite.OpsSet.SELECT_TF_OPS # TensorFlow算子
]
tflite_model = converter.convert()
with open('../models/generator.tflite', 'wb') as f:
f.write(tflite_model)
print("✅ 转换完成!模型路径:~/model_conversion/models/generator.tflite")
3. 执行转换
bash
复制
cd ~/model_conversion/scripts
python convert.py
四、本地验证测试
1. 测试脚本 test.py
python
复制
import cv2
import numpy as np
import tensorflow as tf
def preprocess(img_path):
# 预处理(与训练一致)
img = cv2.imread(img_path)
img = cv2.resize(img, (256, 256)) # 固定尺寸
img = img.astype(np.float32) / 127.5 - 1 # 归一化到[-1,1]
return np.expand_dims(img, axis=0) # 添加batch维度
# 加载TFLite模型
interpreter = tf.lite.Interpreter(
model_path='../models/generator.tflite',
num_threads=4 # 多线程加速
)
interpreter.allocate_tensors()
# 准备输入
input_data = preprocess("test.jpg")
input_idx = interpreter.get_input_details()[0]['index']
interpreter.set_tensor(input_idx, input_data)
# 执行推理
interpreter.invoke()
# 获取输出并保存
output = interpreter.get_tensor(
interpreter.get_output_details()[0]['index']
)
output_img = (output[0] + 1) * 127.5
cv2.imwrite("../output/result.jpg", output_img)
2. 运行测试
bash
复制
python test.py
五、错误解决方案
❌ 错误:Unsupported ONNX opset version: 14
python
复制
# 修改转换脚本中的 opset 版本
torch.onnx.export(..., opset_version=13) # 降低版本
❌ 错误:Input shape mismatch
python
复制
# 在转换脚本中固定输入尺寸
dummy_input = torch.randn(1, 3, 256, 256) # 与模型实际输入一致
六、系统级依赖补全
bash
复制
# Ubuntu/Debian
sudo apt-get install -y \
libgl1-mesa-glx \ # OpenCV硬件加速
libsm6 \ # GUI支持
libxext6
# CentOS/RHEL
sudo yum install -y \
mesa-libGL \
libXext \
libSM
关键验证点
- 环境验证:
bash
复制
python -c "import tensorflow as tf; print(tf.__version__)" # 应输出 2.10.0
- 模型输入/输出尺寸:
bash
复制
tflite_model = interpreter._model # 通过 Netron 可视化检查