I. Feature Extraction
- From the paper: "The left and right images are fed to a weight-sharing feature extraction pipeline. It consists of a stacked hourglass CNN and is connected by concatenations."
- Below, the code is read alongside the architecture diagram.
1. First, look at the GANet class
class GANet(nn.Module):
- The first output, g, is used by the guidance subnet, which learns the weights
g = self.conv_start(x)
x = self.feature(x)
- Next, a walkthrough of the feature extraction module, self.feature
class Feature(nn.Module):
    def __init__(self):
        super(Feature, self).__init__()
        self.conv_start = nn.Sequential(
            BasicConv(3, 32, kernel_size=3, padding=1),
            BasicConv(32, 32, kernel_size=5, stride=3, padding=2),
            BasicConv(32, 32, kernel_size=3, padding=1))
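- First, the 3 RGB channels are mapped to 32 channels.
- A stride=3 convolution then reduces feature_size to 1/3 of the input.
- The **kwargs syntax collects the extra keyword arguments into a dict (basic Python).
- The downsampling step uses the BasicConv module with kernel_size=5, stride=3, padding=2. The rest is standard framework usage: it is followed by one conv2d layer with kernel_size=3, stride=1, padding=1.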
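The 1/3 downsampling can be checked with the standard convolution output-size formula, out = floor((in + 2*padding - kernel_size) / stride) + 1 (for dilation 1). The sketch below applies it to the three conv_start layers. The helper name conv_out_size and the 240 x 528 input size are assumptions made for illustration; they are not part of the GANet code.

```python
def conv_out_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


# (kernel_size, stride, padding) of the three conv_start layers
conv_start_layers = [(3, 1, 1), (5, 3, 2), (3, 1, 1)]

height, width = 240, 528  # hypothetical input size, chosen for illustration
for kernel_size, stride, padding in conv_start_layers:
    height = conv_out_size(height, kernel_size, stride, padding)
    width = conv_out_size(width, kernel_size, stride, padding)

print(height, width)  # 80 176, i.e. 1/3 of 240 x 528
```

Only the middle layer changes the resolution; the two kernel_size=3, padding=1 layers keep the spatial size unchanged.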
class BasicConv(nn.Module):
    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, **kwargs):
        super(BasicConv, self).__init__()
        # print(in_channels, out_channels, deconv, is_3d, bn, relu, kwargs)
        self.relu = relu
        self.use_bn = bn
        if is_3d:
            if deconv:
                self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
            self.bn = BatchNorm3d(out_channels)
        else:
            if deconv:
                self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
            self.bn = BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x
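To make the **kwargs forwarding in BasicConv concrete, here is a minimal pure-Python sketch. Both conv_stub and basic_conv_stub are hypothetical stand-ins written for this illustration (they are not PyTorch or GANet functions); they only mimic how the extra keyword arguments are collected into a dict and then unpacked again when the inner layer is called.

```python
def conv_stub(in_channels, out_channels, bias=True, **kwargs):
    """Hypothetical stand-in for nn.Conv2d: returns the arguments it received."""
    return {"in_channels": in_channels, "out_channels": out_channels,
            "bias": bias, **kwargs}


def basic_conv_stub(in_channels, out_channels, bn=True, relu=True, **kwargs):
    # kwargs is an ordinary dict holding every keyword argument that is not
    # named in the signature (here: kernel_size, stride, padding).
    collected = dict(kwargs)
    # **kwargs at the call site unpacks the dict back into keyword arguments.
    received = conv_stub(in_channels, out_channels, bias=False, **kwargs)
    return collected, received


collected, received = basic_conv_stub(32, 32, kernel_size=5, stride=3, padding=2)
print(collected)  # {'kernel_size': 5, 'stride': 3, 'padding': 2}
print(received["bias"], received["stride"])  # False 3
```

This is why BasicConv never has to list kernel_size, stride or padding itself: whatever the caller passes is handed through to the underlying convolution unchanged.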
- A series of stride=2 conv2d layers keeps reducing feature_size (to 1/6, 1/12, 1/24, 1/48), while the channel count grows to 48, 64, 96, 128.
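As a rough check of these scales, the sketch below pushes a size through the stride=3 layer and then through four kernel_size=3, stride=2, padding=1 layers. The channel pairs are taken from the BasicConv calls listed in this section; the helper conv_out_size and the input height of 240 are assumptions made for illustration.

```python
def conv_out_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


# (in_channels, out_channels) of the four stride-2 layers
stride2_layers = [(32, 48), (48, 64), (64, 96), (96, 128)]

full_size = 240  # hypothetical input height, chosen for illustration
size = conv_out_size(full_size, kernel_size=5, stride=3, padding=2)  # 1/3 scale

pyramid = []
for in_channels, out_channels in stride2_layers:
    size = conv_out_size(size, kernel_size=3, stride=2, padding=1)
    pyramid.append((out_channels, size, full_size // size))

for out_channels, size, factor in pyramid:
    print(f"channels={out_channels:3d}  size={size:3d}  scale=1/{factor}")
# sizes 40, 20, 10, 5 correspond to scales 1/6, 1/12, 1/24, 1/48
```

With kernel_size=3 and padding=1, each stride=2 layer maps a size n to ceil(n / 2), so the scales come out exact when the input size is a multiple of 48.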
1)
self.conv1a = BasicConv(32, 48, kernel_size=3, stride=2, padding=1)
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
2)
self.conv2a = BasicConv(48, 64, kernel_size=3, stride=2, padding=1)
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
3)
```python
self.conv3a = BasicConv(64, 96, kernel_size=3, stride=2, padding=1)
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
```
4)
```python
self.conv4a = BasicConv(96, 128, kernel_size=3, stride=2, padding=