参考文章
问题
pytorch 的神经网络模型有很多,但 libtorch 就特别少。现在面临的问题是要在 C++ 环境下应用神经网络模型,肯定不能直接使用 pytorch 模型。解决办法有两个:
-
方法一是用 TorchScript 工具导出模型 poolnet.pt,模型中包含网络结构和参数权重,因此可以直接在 C++ 里面生成神经网络。
-
方法二是用 C++ 复现网络结构,封装为为类对象,再从 poolnet.pt 中导入参数权重。
对于神经网络模型 PoolNet ,将其应用到 C++ 环境下进行视频处理,下面这是前 10 帧画面处理时间。明显看出,方法一前两次运行时间很长,从第三帧开始,两种方法的处理时间几乎相同。但是,方法一相当简单,导出模型即可,方法二需要复现网络结构,工程量巨大。下面重点介绍方法一。
TorchScript 工具介绍
必定要看 官方文档。上面介绍了 trace 和 script 的区别。
PyTorch 导出模型
resnet50
编辑 export.py 文件,以 pytorch 提供的 resnet50 为例,分别使用 trace 和 script 导出模型。trace 需要提供一个输入样例,script 则不需要。但是复杂的模型使用 script 一般会失败,但 trace 可以。trace 和 script 导出的模型几乎没有区别,缺点是前两次处理时间都格外久。
import torch
from torchvision.models import resnet50
net = resnet50(pretrained=True)
net = net.cuda()
net.eval()
for key, value in net.named_parameters():
print(key)
# trace
x = torch.ones(1, 3, 224, 224)
x = x.cuda()
traced_module = torch.jit.trace(net, x)
traced_module.save("resnet50_trace.pt")
# script
scripted_module = torch.jit.script(net)
scripted_module.save("resnet50_script.pt")
在 python3+pytorch 的虚拟环境下执行
python export.py
net
上面的例子是 pytorch 提供的 resnet50,如果是自己写的模型,可以按照下面的方式来。其中,net.pth 是训练后保存的参数,net.pt 则是期望导出的模型,使用 trace 方法。
import torch
import torchvision
# 初始化神经网络
net = Net()
net.load_state_dict(torch.load("net.pth"))
net.cuda()
net.eval()
# 导出模型
x = torch.ones(1, 3, 224, 224)
x = x.cuda()
m = torch.jit.trace(net, x)
m.save("net.pt")
其中
m.save("net.pt")
也可写为
torch.jit.save(m, "net.pt")
C++ 中调用模型
C++ 使用 libtorch 时,一般使用 CMake 进行管理(参考 Pytorch 官网教程)。下面是在 C++ 环境中调用模型的方法。
#include <torch/torch.h>
#include <torch/script.h>
torch::Device device(torch::kCUDA);
// image.rows, image.cols 高在前,宽在后
torch::Tensor img_tensor = torch::from_blob(img.data, {1, image.rows, image.cols, 3}, torch::kByte).to(device);
img_tensor = img_tensor.permute({0, 3, 1, 2});
img_tensor = img_tensor.toType(torch::kFloat);
img_tensor = img_tensor.div(255.0);
torch::jit::script::Module net = torch::jit::load("../models/net.pt");
// 打印模型中的参数
for (const auto& pair : net.named_parameters()) {
std::cout << pair.name << " " << pair.value.requires_grad() << std::endl;
}
net.to(device)
torch::NoGradGuard no_grad;
torch::Tensor output = net.forward({img_tensor}).toTensor();