class DeepLabV3(_SimpleSegmentationModel):
    """
    Implements the DeepLabV3 model from
    `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_.

    All behavior is inherited unchanged from ``_SimpleSegmentationModel``;
    this subclass only supplies the model's name and documentation.

    Arguments:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """
class DeepLabHead(nn.Sequential):
    """Final DeepLabV3 head that fuses the backbone's features into class scores.

    Layout: ASPP -> 3x3 Conv2d(256, 256) -> BatchNorm2d -> ReLU -> 1x1 Conv2d(256, num_classes).

    Args:
        in_channels: number of channels of the feature map fed into ASPP.
        num_classes: number of output channels of the final 1x1 conv, one per
            class (e.g. 21 in the setup this code was written for).
        atrous_rates: dilation rates for ASPP's dilated 3x3 branches. Defaults to
            (12, 24, 36), the values that were previously hard-coded here.
    """

    def __init__(self, in_channels, num_classes, atrous_rates=(12, 24, 36)):
        super(DeepLabHead, self).__init__(
            ASPP(in_channels, atrous_rates),
            # ASPP projects its output to 256 channels, hence the 256-channel refinement block.
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, 1),
        )
class ASPPConv(nn.Sequential):
    """One dilated 3x3 branch of ASPP: dilated Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        dilation: dilation rate of the 3x3 convolution. The padding is set to
            the same value, so a stride-1 3x3 kernel keeps the input's H x W.
    """

    def __init__(self, in_channels, out_channels, dilation):
        dilated_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            bias=False,
        )
        super().__init__(dilated_conv, nn.BatchNorm2d(out_channels), nn.ReLU())
class ASPPPooling(nn.Sequential):
    """Image-level pooling branch of ASPP.

    Averages the input down to 1x1, applies 1x1 Conv2d -> BatchNorm2d -> ReLU,
    then bilinearly upsamples the result back to the input's spatial size.

    NOTE(review): BatchNorm2d here sees a 1x1 map, so in training mode it likely
    needs a batch size > 1. Confirm against the PyTorch BatchNorm behavior.
    """

    def __init__(self, in_channels, out_channels):
        layers = (
            nn.AdaptiveAvgPool2d(1),  # collapse H x W to 1 x 1
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        super().__init__(*layers)

    def forward(self, x):
        # Keep the input's H x W so the pooled features can be broadcast back to it.
        out_size = x.shape[-2:]
        pooled = super().forward(x)
        return F.interpolate(pooled, size=out_size, mode='bilinear', align_corners=False)
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Runs parallel branches over the same input feature map and fuses them:
      * a 1x1 Conv2d -> BatchNorm2d -> ReLU branch,
      * one dilated 3x3 ``ASPPConv`` branch per entry of ``atrous_rates``,
      * an image-level ``ASPPPooling`` branch.
    Each branch outputs ``out_channels`` channels. The branch outputs are
    concatenated along dim 1 and projected back to ``out_channels`` by
    1x1 Conv2d -> BatchNorm2d -> ReLU -> Dropout(0.5).

    Args:
        in_channels: channels of the input feature map.
        atrous_rates: iterable of dilation rates, one ASPPConv branch per rate.
            Previously exactly three rates were required.
        out_channels: width of every branch and of the output. Defaults to 256,
            the value previously hard-coded.
    """

    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()
        rates = tuple(atrous_rates)
        modules = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )
        ]
        modules.extend(ASPPConv(in_channels, out_channels, rate) for rate in rates)
        modules.append(ASPPPooling(in_channels, out_channels))
        self.convs = nn.ModuleList(modules)
        # The projection input width follows the actual number of concatenated
        # branches (1x1 + len(rates) dilated + pooling) instead of a hard-coded 5.
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        # The dilated 3x3 branches use padding == dilation and the pooling branch
        # upsamples back to x's size, so all branch outputs share H x W and can be
        # concatenated along the channel dimension.
        res = torch.cat([conv(x) for conv in self.convs], dim=1)
        return self.project(res)
DeepLabV3 (ResNet-101) 代码理解
最新推荐文章于 2024-05-03 14:38:54 发布