上一篇文章我曾经提到过,libtorch在进行图像实时检测中性能并不突出,很多时候无法满足我们的需求,后来我在想能不能现在pytorch上训练模型,只将libtorch作为一个加载模型的工具呢?经过尝试我发现这种方法是可行的,并且无论是在运行时间上还是在预测的准确率上都要优于前者,本文章将介绍如何在pytorch上训练模型并用libtorch进行加载预测。
pytorch安装
pytorch官方为我们提供了非常方便的安装渠道,在官网上选择适配自己电脑的选项后用官方给出的命令行下载即可,如下图:
验证安装成功
终端键入
python3
import torch
若如下图成功执行,即表示安装成功(记得选择对应的python版本)
pytorch构建网络
主要用到的就是torch中的nn、autograd、torchvision和tqdm模块,在nn.Module的基础上继承并进行构建,下面给出例程代码,同样以LeNet5为例:
class LeNet5(nn.Module):
def __init__(self, num_class):
super(LeNet5, self).__init__()
#卷积层
self.Conv = nn.Sequential(
nn.Conv2d(1, 6, 5, stride=1, padding=2),
nn.BatchNorm2d(6),
nn.ReLU(True),
#池化
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.BatchNorm2d(16),
nn.ReLU(True),
#池化
nn.MaxPool2d(2, 2)
)
#全连接层
self.FC = nn.Sequential(
nn.Linear(400, 120),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(120, 84),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(84, num_class)
)
def forward(self, x):
#卷积层
out = self.Conv(x)
out = out.view(out.size(0), -1)
#全连接层
out = self.FC(out)
return out
训练(核心代码)
# 向前传播
out = model(img)
# 计算损失函数
loss = criterion(out, label)
running_loss = loss.item() * label.size(0)
_, pred = torch.max(out, dim=1)
num_correct = (pred == label).sum()
running_acc += num_correct.item()
# 手动清空梯度
optimizer.zero_grad()
# 向后传播
loss.backward()
optimizer.step()
保存模型
# 保存模型
torch.save(model, '../model/Vision_NumDetect.pth')
# pth模型转成pt模型
with torch.no_grad():
model.eval()
trace_script_modile=torch.jit.trace(model, img)
trace_script_modile.save(r"../model/NumDetect.pt") #压缩好的模型存出来
以上均是在pytorch下完成的,然后再通过libtorch载入并预测:
libtorch载入模型
torch::jit::load(path);
预测
/**
* @brief 使用torch将图片传入模型中进行预测
* @return 返回预测结果
*/
int torchForward(torch::jit::script::Module &module, const Mat &src)
{
std::vector<int64_t> sizes = {1, 1, src.rows, src.cols};
at::TensorOptions options(at::ScalarType::Byte);
//将opencv的图像数据转为Tensor张量数据
at::Tensor tensor_image = torch::from_blob(src.data, at::IntList(sizes), options);
//转为浮点型张量数据
tensor_image = tensor_image.toType(at::kFloat);
// 前馈预测
at::Tensor result = module.forward({tensor_image}).toTensor();
auto max_result = result.max(1, true);
int max_index = std::get<1>(max_result).item<int>();
return max_index;
}
总结
由于libtorch是从pytorch上移植的,所以在构建网络的逻辑部分并不完善,构建出的网络性能对于pytorch较差,采用pytorch构建网络并训练模型,再用libtorch进行加载,既解决了实时性检测的性能问题,又解决在C++上的适配问题。如果有什么错误的地方,欢迎留言交流!
喜欢的话可以关注一下我的公众号技术开发小圈,尤其是对深度学习以及计算机视觉有兴趣的朋友,我会把相关的源码以及更多资料发在上面,希望可以帮助到新入门的大家!