libtorch (pytorch c++) 教程(六)

该教程介绍了如何使用libtorch(PyTorch的C++接口)构建U-Net模型,重点是ResNet编码器的搭建和U-Net解码器的设计。通过C++实现,展示了从基础模块到整个模型的构建过程,并且提到了模型在CPU和GPU上的执行效率。
摘要由CSDN通过智能技术生成

本教程分多个章节:


本章简要介绍如何如何用C++实现一个语义分割器模型,该模型具有训练和预测的功能。本文的分割模型架构使用简单的U-Net结构,代码结构参考了 qubvel segmentation中的U-Net部分,该项目简称SMP,是基于pytorch实现的开源语义分割项目。本文分享的c++模型几乎完美复现了python的版本。

模型简介

简单介绍一下U-Net模型。U-Net模型的提出是在医学图像分割中,相比于当时的其他模型结构,U-Net的分割能力具有明显优势。一个经典的U-Net结构图如下:
U-Net结构图

U-Net模型采用典型的编码器-解码器结构,左边的编码部分类似VGG模型,是双卷积+下采样的多次堆叠。U-Net模型右边的解码部分同样是双卷积,但是为了得到接近原始输入图像大小的输出图像,针对编码的下采样实施了对应的上采样。最重要的是,U-Net之所以效果突出,重要原因在于其在解码部分利用了编码环节的特征图,拼接编码和解码的特征图,再对拼接后特征图卷积上采样,重复多次得到解码输出。

编码器——ResNet

本文介绍的编码器使用ResNet网络,同时可以像第五章一样加载预训练权重,即骨干网络为ImageNet预训练的ResNet。话不多说,直接上c++的ResNet代码。

Block搭建

建议看本文代码时打开pytorch的torchvision中的resnet.py,对比阅读。

首先是基础模块,pytorch针对resnet18,resne34和resnet50,resnet101,resnet152进行分类,resnet18与resnet34均使用BasicBlock,而更深的网络使用BottleNeck。我不想使用模板类编程,就直接将两个模块合为一体。声明如下:

class BlockImpl : public torch::nn::Module {
   
public:
    BlockImpl(int64_t inplanes, int64_t planes, int64_t stride_ = 1,
        torch::nn::Sequential downsample_ = nullptr, int groups = 1, int base_width = 64, bool is_basic = true);
    torch::Tensor forward(torch::Tensor x);
    torch::nn::Sequential downsample{
    nullptr };
private:
    bool is_basic = true;
    int64_t stride = 1;
    torch::nn::Conv2d conv1{
    nullptr };
    torch::nn::BatchNorm2d bn1{
    nullptr };
    torch::nn::Conv2d conv2{
    nullptr };
    torch::nn::BatchNorm2d bn2{
    nullptr };
    torch::nn::Conv2d conv3{
    nullptr };
    torch::nn::BatchNorm2d bn3{
    nullptr };
};
TORCH_MODULE(Block);

可以发现,其实是直接声明了三个conv结构和一个is_basic标志位判断定义时进行BasicBlock定义还是BottleNeck定义。下面时其定义

BlockImpl::BlockImpl(int64_t inplanes, int64_t planes, int64_t stride_,
    torch::nn::Sequential downsample_, int groups, int base_width, bool _is_basic)
{
   
    downsample = downsample_;
    stride = stride_;
	int width = int(planes * (base_width / 64.)) * groups;

    conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 3, stride_, 1, groups, false));
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
    conv2 = torch::nn::Conv2d(conv_options(width, width, 3, 1, 1, groups, false));
    bn2 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
    is_basic = _is_basic;
    if (!is_basic) {
   
        conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 1, 1, 0, 1, false));
        conv2 = torch::nn::Conv2d(conv_options(width, width, 3, stride_, 1, groups, false));
        conv3 = torch::nn::Conv2d(conv_options(width, planes * 4, 1, 1, 0, 1, false));
        bn3 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes * 4));
    }

    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("conv2", conv2);
    register_module("bn2", bn2);
    if (!is_basic) {
   
        register_module("conv3", conv3);
        register_module("bn3", bn3);
    }

    if (!downsample->is_empty()) {
   
        register_module("downsample", downsample);
    }
}

torch::Tensor BlockImpl::forward(torch::Tensor x) {
   
    torch::Tensor residual = x.clone();

    x = conv1->forward(x);
    x = bn1->forward(x);
    x = torch::relu(x);

    x = conv2->forward(x);
    x = bn2->forward(x);

    if (!is_basic) {
   
        x = torch::relu(x);
        x = conv3->forward(x);
        x = bn3->forward(x);
    }

    if (!downsample->is_empty()) {
   
        residual = downsample->forward(residual);
    }

    x += residual;
    x = torch::relu(x);

    return x;
}

然后不要忘了熟悉的conv_options函数,定义如下:

inline torch::nn::Conv2dOptions conv_options(int64_t in_planes, int64_t out_planes, int64_t kerner_size,
    int64_t stride = 1, int64_t padding = 0, int groups = 1, bool with_
  • 11
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值