目录
一、模型结构
假设输入形状为(4,3,128,128),即一个batch包含4张图片,每张为3通道、尺寸128×128。
1.Backbone
本模型使用ResNet101作为backbone。输入经过backbone之后得到的高层特征形状为(4,2048,8,8)(输出步幅为16)。同时,我们在ResNet101的第一个block之后,就把该处的结果作为底层特征一并输出,形状为(4,256,32,32)。这正是与DeepLabV3不同的地方:这个底层特征将送入decoder,进行特征融合。
2.Encoder
就是ASPP模块。这个模块中有1个1×1卷积、3个3×3的膨胀卷积(膨胀率分别为6、12、18),以及一个全局平均池化分支;把这五个分支的结果在通道维度上拼接起来,再通过一个1×1卷积,就得到了encoder的输出,形状是(4,256,8,8)。这个输出将进入decoder,和前面backbone输出的底层特征进行融合。
3.Decoder
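这部分首先用1×1卷积把底层特征的通道数从256压缩到48,然后把encoder输出的特征图进行上采样(这里为4倍,8×8→32×32),使其和底层特征的尺寸一致;将这两者在通道维度上拼接(48+256=304通道)之后,再进行卷积(两个3×3卷积,以及一个输出通道数等于类别数的1×1卷积),最后再次进行4倍上采样(32×32→128×128)得到最终的输出。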
二、模型实现
代码中的F.interpolate是函数式的上采样方法,和nn.Upsample模块的效果一样。
import torch.nn as nn
import torch
from resnet import ResNet101
import torch.nn.functional as F
# Module-level ResNet101 instance (the backbone used by this model).
# NOTE(review): `net` is not referenced by any class below — DeepLab3p builds
# its own ResNet101 — yet this line constructs a full ResNet101 every time the
# module is imported. Consider removing it if nothing else imports `net`.
net = ResNet101()
# ASPP building block: a single (atrous) convolution branch.
class _ASPPModule(nn.Module):
    """One ASPP branch: (atrous) Conv2d -> BatchNorm2d -> ReLU.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        kernel_size: convolution kernel size.
        padding: zero-padding applied by the convolution.
        dilation: dilation (atrous) rate of the convolution.
    """

    def __init__(self, inplanes, planes, kernel_size, padding, dilation):
        super().__init__()
        # The conv keeps PyTorch's default bias=True even though BatchNorm
        # follows; left unchanged so parameter names / state_dict stay stable.
        self.atrous_conv = nn.Conv2d(
            in_channels=inplanes,
            out_channels=planes,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
        )
        self.bn = nn.BatchNorm2d(num_features=planes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the convolution, batch-norm and ReLU to ``x``, in that order."""
        return self.relu(self.bn(self.atrous_conv(x)))
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling — the DeepLabV3+ encoder head.

    Runs five parallel branches on the backbone's high-level feature map
    (one 1x1 conv, three 3x3 atrous convs and an image-level pooling branch),
    concatenates them along the channel axis and fuses the result with a
    1x1 conv + BN + ReLU + Dropout. The output keeps the input's spatial size.

    Args:
        inplanes: channels of the incoming feature map (2048 for the
            ResNet101 backbone used in this file).
        planes: output channels of every branch and of the fused result.
        dilations: exactly four dilation rates. The first is used by the 1x1
            branch (where dilation has no effect); the other three are used by
            the 3x3 atrous branches.
        dropout: drop probability applied to the fused output.

    Raises:
        ValueError: if ``dilations`` does not contain exactly four rates.
    """

    def __init__(self, inplanes=2048, planes=256, dilations=(1, 6, 12, 18), dropout=0.5):
        super(ASPP, self).__init__()
        dilations = tuple(dilations)
        if len(dilations) != 4:
            raise ValueError(
                f"ASPP expects exactly 4 dilation rates, got {len(dilations)}"
            )
        # 1x1 branch: padding=0 keeps the spatial size.
        self.aspp1 = _ASPPModule(inplanes, planes, 1, padding=0, dilation=dilations[0])
        # 3x3 atrous branches: padding == dilation keeps the spatial size.
        self.aspp2 = _ASPPModule(inplanes, planes, 3, padding=dilations[1], dilation=dilations[1])
        self.aspp3 = _ASPPModule(inplanes, planes, 3, padding=dilations[2], dilation=dilations[2])
        self.aspp4 = _ASPPModule(inplanes, planes, 3, padding=dilations[3], dilation=dilations[3])
        # Image-level branch: global average pool to 1x1, then 1x1 conv + BN + ReLU.
        # NOTE(review): BatchNorm over a 1x1 map needs more than one sample per
        # batch in training mode; a training batch of size 1 will error here.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, planes, 1, stride=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(),
        )
        # Fuse the five concatenated branches back down to `planes` channels
        # (5 * 256 = 1280 input channels with the defaults).
        self.conv1 = nn.Conv2d(5 * planes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the fused ASPP features for ``x`` (same spatial size as ``x``)."""
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # Broadcast the 1x1 pooled features back to the branch resolution.
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat([x1, x2, x3, x4, x5], dim=1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        return self.dropout(x)
class Decoder(nn.Module):
    """DeepLabV3+ decoder: fuse the ASPP output with low-level backbone features.

    The low-level features are reduced to 48 channels by a 1x1 conv + BN + ReLU.
    The ASPP output is bilinearly upsampled to the low-level features' spatial
    size. The two are concatenated along the channel axis, refined by two 3x3
    convs, and projected to per-class logits by a final 1x1 conv. The result
    stays at the low-level resolution; upsampling to the input size is left to
    the caller.

    Args:
        num_classes: number of output channels (classes).
        low_level_inplanes: channels of the low-level backbone features
            (256 for the ResNet101 first-block output used in this file).
        aspp_planes: channels of the ASPP output fed in as ``x``.
    """

    def __init__(self, num_classes, low_level_inplanes=256, aspp_planes=256):
        super(Decoder, self).__init__()
        reduced_planes = 48
        # 1x1 conv shrinks the low-level features so they don't dominate the
        # concatenation with the (richer) ASPP features.
        self.conv1 = nn.Conv2d(low_level_inplanes, reduced_planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(reduced_planes)
        self.relu = nn.ReLU()
        # Input channels = reduced low-level (48) + ASPP (256) = 304 by default.
        self.last_conv = nn.Sequential(
            nn.Conv2d(reduced_planes + aspp_planes, 256, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
        )

    def forward(self, x, low_level_feat):
        """Fuse ``x`` (ASPP output) with ``low_level_feat`` and return class logits.

        The returned tensor has ``num_classes`` channels and the spatial size
        of ``low_level_feat``.
        """
        low_level_feat = self.conv1(low_level_feat)
        low_level_feat = self.bn1(low_level_feat)
        low_level_feat = self.relu(low_level_feat)
        # Upsample the coarse ASPP features to the low-level resolution.
        x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat([low_level_feat, x], dim=1)
        return self.last_conv(x)
class DeepLab3p(nn.Module):
    """DeepLabV3+ segmentation network: ResNet101 backbone + ASPP + decoder.

    Args:
        num_classes: number of output channels (classes) of the logits map.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Assumes the backbone's forward returns a pair
        # (high-level features, low-level features) — TODO confirm against
        # resnet.ResNet101.
        self.backbone = ResNet101()
        self.aspp = ASPP()
        self.decoder = Decoder(num_classes=num_classes)

    def forward(self, x_in):
        """Return per-class logits resized to the spatial size of ``x_in``."""
        high_level, low_level = self.backbone(x_in)
        encoded = self.aspp(high_level)
        logits = self.decoder(encoded, low_level)
        # The decoder works at low-level resolution; bring it back to input size.
        out_size = x_in.size()[2:]
        return F.interpolate(logits, out_size, mode='bilinear', align_corners=True)