A Brief Analysis of the DeepLab v3+ Architecture Code (PyTorch Version)

Reposted from: https://www.cnblogs.com/ywheunji/p/10479019.html. Will be removed on request if it infringes.

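DeepLab v3+ adds a decoder module to recover precise object boundaries. The comparison is shown in the figure.

[Figure: architecture comparison]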

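DeepLab v3+ uses ASPP, a multi-scale atrous convolution structure similar to the one in DeepLab v3. The ASPP output is upsampled and concatenated with low-level features from an earlier convolutional layer, then passed through further convolutions and a final upsampling step to produce the result.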
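
To make the data flow described above more concrete, the following sketch traces tensor shapes through the pipeline for a given input size. The channel counts (2048, 256, 48) are read off the code listing in this post; the helper name `trace_shapes` is hypothetical and exists only for illustration.

```python
import math


def trace_shapes(height, width, output_stride=16, n_classes=21):
    """Return (stage name, channels, height, width) for each stage of the pipeline."""
    # The deepest backbone feature map is downsampled by the output stride,
    # while the low-level feature (taken after layer1) is downsampled by 4.
    deep_h = math.ceil(height / output_stride)
    deep_w = math.ceil(width / output_stride)
    low_h = math.ceil(height / 4)
    low_w = math.ceil(width / 4)
    aspp_branches = 5  # four atrous branches plus the global average pooling branch
    return [
        ("backbone output", 2048, deep_h, deep_w),
        ("low-level feature", 256, low_h, low_w),
        ("ASPP concat", 256 * aspp_branches, deep_h, deep_w),
        ("ASPP after 1x1 conv", 256, deep_h, deep_w),
        ("upsampled to 1/4", 256, low_h, low_w),
        ("decoder concat", 256 + 48, low_h, low_w),
        ("logits", n_classes, low_h, low_w),
        ("final output", n_classes, height, width),
    ]


if __name__ == "__main__":
    for name, channels, h, w in trace_shapes(512, 512):
        print(f"{name:22s} {channels:5d} x {h} x {w}")
```

The 1280 and 304 input channels of the `conv1` and `last_conv` layers in the listing correspond to the "ASPP concat" and "decoder concat" rows.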

DeepLab v3:

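With the proposed encoder-decoder structure, the resolution of the encoder features can be set arbitrarily by controlling the atrous convolution, which trades off accuracy against runtime (existing encoder-decoder structures do not have this ability).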
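
The following sketch illustrates how swapping stride for dilation changes the encoder's output resolution. It uses the standard convolution output-size formula and the per-stage strides and dilations from the `ResNet` class in this post. It models only the first 3x3 convolution of each stage, which is a simplification; `conv_out` and `backbone_output_size` are hypothetical helper names.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of a convolution or pooling layer (PyTorch convention)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Per-stage settings copied from the ResNet class in this post.
STAGE_CONFIG = {
    16: {"strides": [1, 2, 2, 1], "dilations": [1, 1, 1, 2]},
    8: {"strides": [1, 2, 1, 1], "dilations": [1, 1, 2, 2]},
}


def backbone_output_size(size, output_stride):
    """Spatial size of the deepest feature map for a square input of the given size."""
    cfg = STAGE_CONFIG[output_stride]
    size = conv_out(size, kernel=7, stride=2, padding=3)  # stem convolution
    size = conv_out(size, kernel=3, stride=2, padding=1)  # max pooling
    for stride, dilation in zip(cfg["strides"], cfg["dilations"]):
        # A 3x3 conv with padding equal to its dilation keeps the size when stride is 1.
        size = conv_out(size, kernel=3, stride=stride,
                        padding=dilation, dilation=dilation)
    return size


if __name__ == "__main__":
    for os_ in (16, 8):
        print(os_, backbone_output_size(512, os_))
```

A smaller output stride keeps a larger feature map, which costs more computation in every later layer.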

The ASPP module can be used to capture contextual information at different scales.
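
One way to see why different atrous rates capture different scales is to compute the effective extent of each kernel: a kernel of size k with dilation d covers k + (k - 1)(d - 1) input positions per side. The sketch below applies this to the rates used in the code in this post; as in `ASPP_module`, rate 1 is treated as a 1x1 convolution. The function name is hypothetical.

```python
def aspp_effective_kernels(dilations):
    """Effective per-side kernel extent of each ASPP branch."""
    extents = []
    for d in dilations:
        k = 1 if d == 1 else 3  # mirrors ASPP_module: 1x1 conv when dilation is 1
        extents.append(k + (k - 1) * (d - 1))
    return extents


if __name__ == "__main__":
    print(aspp_effective_kernels([1, 6, 12, 18]))   # rates for output stride 16
    print(aspp_effective_kernels([1, 12, 24, 36]))  # rates for output stride 8
```

The rates double at output stride 8 because the feature map is twice as large, so each branch covers roughly the same fraction of the image.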

PSPNet, by contrast, performs pooling at several grid scales to capture multi-scale contextual information.
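
For comparison with the atrous approach, here is a rough pure-Python sketch of pooling one feature map into grids of several sizes. The bin sizes 1, 2, 3 and 6 are the setting commonly cited for PSPNet and should be treated as an assumption here. The helper `adaptive_avg_pool2d` is a hypothetical stand-in that follows the usual adaptive-pooling bin boundaries.

```python
import math


def adaptive_avg_pool2d(grid, out_size):
    """Average-pool a 2D list into an out_size x out_size grid of bins."""
    h, w = len(grid), len(grid[0])
    pooled = []
    for i in range(out_size):
        # Bin i covers rows [floor(i*h/out), ceil((i+1)*h/out))
        r0, r1 = (i * h) // out_size, math.ceil((i + 1) * h / out_size)
        row = []
        for j in range(out_size):
            c0, c1 = (j * w) // out_size, math.ceil((j + 1) * w / out_size)
            cells = [grid[r][c] for r in range(r0, r1) for c in range(c0, c1)]
            row.append(sum(cells) / len(cells))
        pooled.append(row)
    return pooled


# A toy 6x6 single-channel feature map with values 0..35.
feature = [[r * 6 + c for c in range(6)] for r in range(6)]
# One pooled summary per pyramid level (assumed bin sizes).
pyramid = {bins: adaptive_avg_pool2d(feature, bins) for bins in (1, 2, 3, 6)}
```

The coarsest level summarizes the whole map in one value, while finer levels keep more spatial detail.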

This DeepLab v3+ implementation uses ResNet-101 as its backbone.

 

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
# SynchronizedBatchNorm2d is imported from the project's own modeling package, not from torch
from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d

BatchNorm2d = SynchronizedBatchNorm2d

class Bottleneck(nn.Module):
    # Basic building block of the ResNet network
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        # 1x1 conv reduces channels, 3x3 (atrous) conv processes them, 1x1 conv expands by 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the shortcut when the shape of the main path changes
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    # Components of the ResNet backbone
    def __init__(self, nInputChannels, block, layers, os=16, pretrained=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        # os is the output stride; it selects the per-stage strides and dilations
        # and the multi-grid rates used in layer4
        if os == 16:
            strides = [1, 2, 2, 1]
            dilations = [1, 1, 1, 2]
            blocks = [1, 2, 4]
        elif os == 8:
            strides = [1, 2, 1, 1]
            dilations = [1, 1, 2, 2]
            blocks = [1, 2, 1]
        else:
            raise NotImplementedError

        # Modules
        self.conv1 = nn.Conv2d(nInputChannels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2])
        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3])

        self._init_weight()

        if pretrained:
            self._load_pretrained_model()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            # Pass the dilation on so every block in the stage uses the same atrous rate
            # (the original listing omitted it, which only matters for layer3 when os == 8)
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks=(1, 2, 4), stride=1, dilation=1):
        # Multi-grid unit: each block uses its own rate, blocks[i] * dilation
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, len(blocks)):
            layers.append(block(self.inplanes, planes, stride=1, dilation=blocks[i]*dilation))

        return nn.Sequential(*layers)

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        # The layer1 output (1/4 resolution, 256 channels) is kept as the low-level feature
        low_level_feat = x
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x, low_level_feat

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _load_pretrained_model(self):
        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
        model_dict = {}
        state_dict = self.state_dict()
        # Copy only the pretrained weights whose keys also exist in this model
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

def ResNet101(nInputChannels=3, os=16, pretrained=False):
    model = ResNet(nInputChannels, Bottleneck, [3, 4, 23, 3], os, pretrained=pretrained)
    return model


class ASPP_module(nn.Module):
    # A single ASPP branch: atrous convolution + BatchNorm + ReLU
    def __init__(self, inplanes, planes, dilation):
        super(ASPP_module, self).__init__()
        # A dilation of 1 means a plain 1x1 conv; otherwise a 3x3 atrous conv
        # whose padding equals the dilation, so the spatial size is preserved
        if dilation == 1:
            kernel_size = 1
            padding = 0
        else:
            kernel_size = 3
            padding = dilation
        self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                            stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_convolution(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class DeepLabv3_plus(nn.Module):
    # The DeepLab v3+ network itself
    def __init__(self, nInputChannels=3, n_classes=21, os=16, pretrained=False, freeze_bn=False, _print=True):
        if _print:
            print("Constructing DeepLabv3+ model...")
            print("Backbone: Resnet-101")
            print("Number of classes: {}".format(n_classes))
            print("Output stride: {}".format(os))
            print("Number of Input Channels: {}".format(nInputChannels))
        super(DeepLabv3_plus, self).__init__()

        # Atrous Conv: first obtain the feature maps extracted by ResNet-101
        self.resnet_features = ResNet101(nInputChannels, os, pretrained=pretrained)

        # ASPP: choose the dilation rates according to the output stride
        if os == 16:
            dilations = [1, 6, 12, 18]
        elif os == 8:
            dilations = [1, 12, 24, 36]
        else:
            raise NotImplementedError

        # Four atrous convolutions with different rates, giving different receptive fields
        self.aspp1 = ASPP_module(2048, 256, dilation=dilations[0])
        self.aspp2 = ASPP_module(2048, 256, dilation=dilations[1])
        self.aspp3 = ASPP_module(2048, 256, dilation=dilations[2])
        self.aspp4 = ASPP_module(2048, 256, dilation=dilations[3])

        self.relu = nn.ReLU()

        # Global average pooling branch
        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(2048, 256, 1, stride=1, bias=False),
                                             BatchNorm2d(256),
                                             nn.ReLU())

        # 1280 = 5 branches x 256 channels
        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.bn1 = BatchNorm2d(256)

        # adopt [1x1, 48] for channel reduction.
        self.conv2 = nn.Conv2d(256, 48, 1, bias=False)
        self.bn2 = BatchNorm2d(48)

        # The final 3x3 convolution block of the decoder in the architecture diagram
        # (304 = 256 ASPP channels + 48 low-level channels)
        self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm2d(256),
                                       nn.ReLU(),
                                       nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm2d(256),
                                       nn.ReLU(),
                                       nn.Conv2d(256, n_classes, kernel_size=1, stride=1))
        if freeze_bn:
            self._freeze_bn()

    # Forward pass
    def forward(self, input):
        x, low_level_features = self.resnet_features(input)
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # F.upsample is deprecated; F.interpolate is the current equivalent
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        # Concatenate the four ASPP branches and the global pooling branch
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        # Reduce channels with a 1x1 conv, then upsample to 1/4 of the input size
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = F.interpolate(x, size=(int(math.ceil(input.size()[-2]/4)),
                                   int(math.ceil(input.size()[-1]/4))), mode='bilinear', align_corners=True)

        low_level_features = self.conv2(low_level_features)
        low_level_features = self.bn2(low_level_features)
        low_level_features = self.relu(low_level_features)

        # Concatenate the low-level features, then interpolate back to the input size
        x = torch.cat((x, low_level_features), dim=1)
        x = self.last_conv(x)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

    def _freeze_bn(self):
        # Put every BatchNorm layer in eval mode so its running statistics stay fixed
        for m in self.modules():
            if isinstance(m, BatchNorm2d):
                m.eval()

    def _init_weight(self):
        # Not called in __init__; since self.modules() includes the backbone, calling it
        # after construction would also overwrite any pretrained backbone weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

def get_1x_lr_params(model):
    """
    This generator yields the trainable parameters of the ResNet backbone,
    which are meant to be trained with the base (1x) learning rate.
    Parameters with requires_grad set to False are skipped.
    """
    b = [model.resnet_features]
    for i in range(len(b)):
        for k in b[i].parameters():
            if k.requires_grad:
                yield k


def get_10x_lr_params(model):
    """
    This generator yields the trainable parameters of the ASPP branches and
    the decoder convolutions, which are meant to be trained with a 10x learning rate.
    Note that global_avg_pool, bn1 and bn2 are not in the list below.
    """
    b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.conv1, model.conv2, model.last_conv]
    for j in range(len(b)):
        for k in b[j].parameters():
            if k.requires_grad:
                yield k


if __name__ == "__main__":
    # pretrained=True downloads the ResNet-101 weights on first use
    model = DeepLabv3_plus(nInputChannels=3, n_classes=21, os=16, pretrained=True, _print=True)
    model.eval()
    image = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        output = model(image)
    # Expected shape: (1, n_classes, 512, 512), the same spatial size as the input
    print(output.size())
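
The two generators `get_1x_lr_params` and `get_10x_lr_params` are not used anywhere in the listing. One plausible use is to feed them into separate optimizer parameter groups. The sketch below shows that pattern with a tiny stand-in model so that it runs without the project's SynchronizedBatchNorm2d dependency. `TinySegModel`, `trainable`, the choice of SGD and all hyperparameter values are assumptions for illustration, not taken from the original post.

```python
import torch
import torch.nn as nn


class TinySegModel(nn.Module):
    """Hypothetical stand-in with a 'backbone' and a 'head', for illustration only."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.head = nn.Conv2d(8, 21, kernel_size=1)

    def forward(self, x):
        return self.head(torch.relu(self.backbone(x)))


def trainable(module):
    """Yield only parameters that require gradients, like the generators in the listing."""
    for p in module.parameters():
        if p.requires_grad:
            yield p


model = TinySegModel()
base_lr = 0.007  # assumed value

# Each dict is one parameter group; a per-group "lr" overrides the default lr.
optimizer = torch.optim.SGD(
    [
        {"params": trainable(model.backbone), "lr": base_lr},   # 1x group
        {"params": trainable(model.head), "lr": base_lr * 10},  # 10x group
    ],
    lr=base_lr,
    momentum=0.9,
    weight_decay=5e-4,
)
```

With the real model, the two groups would presumably be `get_1x_lr_params(model)` and `get_10x_lr_params(model)` in place of the two `trainable(...)` calls.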

