Stereo Matching: A Code Walkthrough of the GA-Net Network Architecture


I. Feature Extraction


  • From the paper: "The left and right images are fed to a weight-sharing feature extraction pipeline. It consists of a stacked hourglass CNN and is connected by concatenations."
  • Below, we read the code alongside the architecture diagram.
    1. Start with the GANet class:
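```python
class GANet(nn.Module):
    ...
```
  • The first output, `g`, is the branch that feeds the guidance subnet, which learns the aggregation weights.
```python
# excerpt from GANet's forward pass
g = self.conv_start(x)  # guidance-subnet branch
x = self.feature(x)     # feature extraction (the Feature module below)
```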
  • Next, we take apart the feature extraction module `self.feature`.
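The paper's "weight-sharing" means that one set of extractor parameters processes both views. In PyTorch terms, this amounts to calling the same module instance (here `self.feature`) on the left and on the right image, so no second copy of the weights exists. The toy pure-Python sketch below illustrates the idea. `SharedFeatureExtractor` and its 1D kernel are hypothetical stand-ins for the CNN, not GA-Net code.

```python
from typing import List


class SharedFeatureExtractor:
    """Hypothetical stand-in for a weight-sharing (siamese) feature extractor.

    A single 1D kernel plays the role of the CNN weights; the same object,
    and therefore the same parameters, is applied to both views.
    """

    def __init__(self, kernel: List[float]):
        self.kernel = kernel  # the only "learnable" parameters

    def __call__(self, row: List[float]) -> List[float]:
        # 1D "valid" correlation of one scanline with the shared kernel
        k = len(self.kernel)
        return [
            sum(w * row[i + j] for j, w in enumerate(self.kernel))
            for i in range(len(row) - k + 1)
        ]


feature = SharedFeatureExtractor([0.25, 0.5, 0.25])
left_row = [1.0, 2.0, 3.0, 4.0, 5.0]
right_row = [2.0, 3.0, 4.0, 5.0, 6.0]

left_feat = feature(left_row)    # same instance, same weights ...
right_feat = feature(right_row)  # ... applied to the other view
print(left_feat, right_feat)     # [2.0, 3.0, 4.0] [3.0, 4.0, 5.0]
```

Sharing the extractor also halves the parameter count compared with two separate extractors.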


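```python
import torch.nn as nn


class Feature(nn.Module):
    def __init__(self):
        super(Feature, self).__init__()

        self.conv_start = nn.Sequential(
            BasicConv(3, 32, kernel_size=3, padding=1),             # RGB (3) -> 32 channels, size unchanged
            BasicConv(32, 32, kernel_size=5, stride=3, padding=2),  # stride 3: 1/3 resolution
            BasicConv(32, 32, kernel_size=3, padding=1))            # 3x3 conv, stride 1, size unchanged
        # ... further layers follow
```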
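  • First, the 3 RGB channels are mapped to 32 channels.
  • A convolution with stride=3 then shrinks the feature map to 1/3 of the input size.
  • The `**kwargs` syntax (basic Python) collects any extra keyword arguments into a dict. BasicConv passes them unchanged to the underlying convolution layer.
  • The downsampling is done with BasicConv, using kernel_size=5, stride=3, padding=2. After the convolution, BasicConv applies the usual BatchNorm and ReLU. Another conv2d with kernel_size=3, stride=1, padding=1 follows.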
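The size reduction described above follows from the standard convolution output-size formula (dilation 1): out = ⌊(H + 2·padding − kernel_size) / stride⌋ + 1. Below is a minimal pure-Python check of `conv_start`, assuming a hypothetical 240×528 input. `conv_out` and `conv_start_shape` are illustrative helpers, not GA-Net code.

```python
def conv_out(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one axis for a conv with dilation 1 (nn.Conv2d's formula)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# Feature.conv_start as (kernel_size, stride, padding) triples
CONV_START = [(3, 1, 1), (5, 3, 2), (3, 1, 1)]


def conv_start_shape(height: int, width: int) -> tuple:
    """Spatial size after the three conv_start layers."""
    for k, s, p in CONV_START:
        height, width = conv_out(height, k, s, p), conv_out(width, k, s, p)
    return height, width


print(conv_start_shape(240, 528))  # (80, 176): one third of each side
print(conv_start_shape(241, 241))  # (81, 81): ceil(241 / 3) = 81
```

For kernel_size=5, stride=3, padding=2, the formula reduces to ⌈H/3⌉. The output is therefore exactly one third of the input only when H is divisible by 3. The two 3×3, stride-1, padding-1 layers leave the size unchanged.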
```python
import torch.nn as nn
import torch.nn.functional as F
# Assumption: plain PyTorch BatchNorm layers; the original repository may
# import its own (e.g. synchronized) BatchNorm2d / BatchNorm3d instead.
from torch.nn import BatchNorm2d, BatchNorm3d


class BasicConv(nn.Module):
    """Conv (2D or 3D, optionally transposed) -> optional BatchNorm -> optional ReLU."""

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, **kwargs):
        super(BasicConv, self).__init__()
        self.relu = relu
        self.use_bn = bn
        # kwargs (kernel_size, stride, padding, ...) go straight to the conv layer
        if is_3d:
            if deconv:
                self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
            self.bn = BatchNorm3d(out_channels)
        else:
            if deconv:
                self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
            self.bn = BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x
```
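The sketch below mirrors the branching in `BasicConv.__init__` and the role of `**kwargs` with strings instead of real layers, so it runs without PyTorch. `describe_basic_conv` is a hypothetical helper written for illustration. It reports which layers `forward` would apply and which keyword arguments reach the convolution.

```python
def describe_basic_conv(in_channels, out_channels, deconv=False, is_3d=False,
                        bn=True, relu=True, **kwargs):
    """Hypothetical mirror of BasicConv's branching, using layer names instead of modules."""
    # in_channels / out_channels would be passed positionally to the conv layer;
    # they do not affect which layers are chosen.
    dim = "3d" if is_3d else "2d"
    # deconv selects the transposed convolution, is_3d selects the 3D variant
    applied = [("ConvTranspose" if deconv else "Conv") + dim]
    # BasicConv always builds self.bn, but forward() only applies it when bn=True
    if bn:
        applied.append("BatchNorm" + dim)
    if relu:
        applied.append("ReLU")
    # kwargs arrives as a plain dict; the conv gets bias=False plus its contents
    conv_kwargs = {"bias": False, **kwargs}
    return applied, conv_kwargs


layers, conv_kwargs = describe_basic_conv(32, 48, kernel_size=3, stride=2, padding=1)
print(layers)       # ['Conv2d', 'BatchNorm2d', 'ReLU']
print(conv_kwargs)  # {'bias': False, 'kernel_size': 3, 'stride': 2, 'padding': 1}
```

With `deconv=True, is_3d=True, relu=False`, the same call would describe a `ConvTranspose3d` followed by `BatchNorm3d` and no ReLU.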
  • A series of stride-2 conv2d layers continues the downsampling. The feature size goes to 1/6, 1/12, 1/24 and 1/48 of the input, and the channel count becomes 48, 64, 96 and 128. Each call keeps the defaults `deconv=False, is_3d=False`, so inside BasicConv it builds `self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)`.
    1)
```python
self.conv1a = BasicConv(32, 48, kernel_size=3, stride=2, padding=1)   # 1/3 -> 1/6
```
    2)
```python
self.conv2a = BasicConv(48, 64, kernel_size=3, stride=2, padding=1)   # 1/6 -> 1/12
```
    3)
```python
self.conv3a = BasicConv(64, 96, kernel_size=3, stride=2, padding=1)   # 1/12 -> 1/24
```
    4)
```python
self.conv4a = BasicConv(96, 128, kernel_size=3, stride=2, padding=1)  # 1/24 -> 1/48
```
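The pure-Python shape trace below checks the full 1/3 → 1/48 chain and the channel counts end to end. It assumes a hypothetical 240×528 input (both sides divisible by 48) and uses the convolution output-size formula with dilation 1. The layer list restates the kernel, stride, padding and channel values from the excerpts above. `conv_out` and `trace_shapes` are illustrative helpers, not part of GA-Net.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Output length along one axis for a conv with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


# (out_channels, kernel_size, stride, padding): the three conv_start layers,
# then the four stride-2 BasicConv layers
LAYERS = [
    (32, 3, 1, 1), (32, 5, 3, 2), (32, 3, 1, 1),  # conv_start -> 1/3
    (48, 3, 2, 1),   # -> 1/6
    (64, 3, 2, 1),   # -> 1/12
    (96, 3, 2, 1),   # -> 1/24
    (128, 3, 2, 1),  # -> 1/48
]


def trace_shapes(height, width):
    """Return (channels, height, width) after every layer."""
    shapes = []
    for out_c, k, s, p in LAYERS:
        height, width = conv_out(height, k, s, p), conv_out(width, k, s, p)
        shapes.append((out_c, height, width))
    return shapes


for c, h, w in trace_shapes(240, 528):
    print(f"C={c:3d}  H={h:3d}  W={w:3d}")
```

For 240×528, the final map is 128 × 5 × 11, exactly 1/48 of each side. Each stride-2 layer with kernel 3 and padding 1 computes ⌈H/2⌉, so sizes not divisible by 48 get rounded up along the way (241 × 241 ends at 6 × 6). This is one reason to pad or crop inputs to a multiple of 48.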