将 ASPP 中 3 个并列的 3×3 空洞卷积分支，各自替换为串联的一个 3×1 卷积和一个 1×3 卷积（空洞率保持 6、12、18 不变）。
更换前的代码：
import torch
import torch.nn.modules
import torch.nn as nn
from torch.nn import functional as F
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head built from plain convolutions.

    Runs ``len(rates) + 2`` parallel branches on the same input and fuses
    their concatenation with a 1x1 convolution:

    * image pooling: global average pool -> 1x1 conv -> bilinear upsample
      back to the input's spatial size;
    * a 1x1 convolution (``atrous_block1``);
    * one 3x3 atrous convolution per entry of ``rates`` (padding equals the
      dilation, so the spatial size is preserved).

    No normalization or activation is applied anywhere in this module.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count of every branch and of the output.
        rates: dilation rates of the 3x3 atrous branches, in concatenation
            order. Defaults to ``(6, 12, 18)``, the previously hard-coded
            values. Each branch is registered as ``atrous_block{rate}``, so
            with the defaults the parameter names / state-dict keys are the
            same as before.

    Raises:
        ValueError: if ``rates`` has duplicates or a rate below 2.
    """

    def __init__(self, in_channels, out_channels, rates=(6, 12, 18)):
        super().__init__()
        rates = tuple(rates)
        # Branch names are derived from the rate: duplicates would collide,
        # and rate 1 would overwrite the 1x1 branch named "atrous_block1".
        if len(set(rates)) != len(rates) or any(rate < 2 for rate in rates):
            raise ValueError(
                f"rates must be distinct integers >= 2 "
                f"(1 clashes with the 1x1 branch 'atrous_block1'), got {rates}"
            )
        self.rates = rates
        self.mean = nn.AdaptiveAvgPool2d((1, 1))
        self.conv = nn.Conv2d(in_channels, out_channels, 1, 1)
        self.atrous_block1 = nn.Conv2d(in_channels, out_channels, 1, 1)
        # Registered in the same order as before, so seeded initialization
        # consumes the RNG identically for the default rates.
        for rate in rates:
            self.add_module(
                f"atrous_block{rate}",
                nn.Conv2d(in_channels, out_channels, 3, 1, padding=rate, dilation=rate),
            )
        # Inputs: image-pooling branch + 1x1 branch + one branch per rate.
        self.conv_1x1_output = nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1, 1)

    def forward(self, x):
        """Apply every branch to ``x`` and fuse the results.

        Args:
            x: input feature map. Branches are concatenated along dim 1, so
                this presumably expects a batched (N, C, H, W) tensor with
                ``C == in_channels``.

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size
            as ``x``.
        """
        size = x.shape[2:]
        image_features = self.conv(self.mean(x))
        # Broadcast the pooled 1x1 map back to the input's spatial size.
        image_features = F.interpolate(image_features, size=size, mode='bilinear', align_corners=False)
        branches = [image_features, self.atrous_block1(x)]
        branches.extend(getattr(self, f"atrous_block{rate}")(x) for rate in self.rates)
        return self.conv_1x1_output(torch.cat(branches, dim=1))
更换后的代码：
import torch
import torch.nn.modules
import torch.nn as nn
from torch.nn import functional as F
class ASPP2(nn.Module):
    """ASPP variant whose 3x3 atrous branches are factorized into 3x1 -> 1x3.

    Runs ``len(rates) + 2`` parallel branches on the same input and fuses
    their concatenation with a 1x1 convolution:

    * image pooling: global average pool -> 1x1 conv -> bilinear upsample
      back to the input's spatial size;
    * a 1x1 convolution (``atrous_block1``);
    * per entry of ``rates``: a (3, 1) conv ``atrous_block{rate}_1`` followed
      by a (1, 3) conv ``atrous_block{rate}_2``, both dilated by ``rate`` and
      padded only along the axis their kernel spans, so the spatial size is
      preserved.

    The two factorized convs are chained with no activation in between, and
    no normalization or activation is applied anywhere in this module.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count of every branch and of the output.
        rates: dilation rates of the factorized branches, in concatenation
            order. Defaults to ``(6, 12, 18)``, the previously hard-coded
            values; attribute names / state-dict keys are unchanged.

    Raises:
        ValueError: if ``rates`` has duplicates or a rate below 1.
    """

    def __init__(self, in_channels, out_channels, rates=(6, 12, 18)):
        super().__init__()
        rates = tuple(rates)
        # Branch names are derived from the rate, so duplicates would collide.
        if len(set(rates)) != len(rates) or any(rate < 1 for rate in rates):
            raise ValueError(f"rates must be distinct integers >= 1, got {rates}")
        self.rates = rates
        self.mean = nn.AdaptiveAvgPool2d((1, 1))
        self.conv = nn.Conv2d(in_channels, out_channels, 1, 1)
        self.atrous_block1 = nn.Conv2d(in_channels, out_channels, 1, 1)
        # Registered in the same order as before (…_1 then …_2 per rate), so
        # seeded initialization consumes the RNG identically.
        for rate in rates:
            self.add_module(
                f"atrous_block{rate}_1",
                nn.Conv2d(in_channels, out_channels, (3, 1), 1, padding=(rate, 0), dilation=rate),
            )
            # The 1x3 conv consumes the 3x1 conv's output, so its input width
            # is out_channels. It was previously declared with in_channels,
            # which made forward() fail whenever in_channels != out_channels.
            self.add_module(
                f"atrous_block{rate}_2",
                nn.Conv2d(out_channels, out_channels, (1, 3), 1, padding=(0, rate), dilation=rate),
            )
        # Inputs: image-pooling branch + 1x1 branch + one branch per rate.
        self.conv_1x1_output = nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1, 1)

    def forward(self, x):
        """Apply every branch to ``x`` and fuse the results.

        Args:
            x: input feature map. Branches are concatenated along dim 1, so
                this presumably expects a batched (N, C, H, W) tensor with
                ``C == in_channels``.

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size
            as ``x``.
        """
        size = x.shape[2:]
        image_features = self.conv(self.mean(x))
        # Broadcast the pooled 1x1 map back to the input's spatial size.
        image_features = F.interpolate(image_features, size=size, mode='bilinear', align_corners=False)
        branches = [image_features, self.atrous_block1(x)]
        for rate in self.rates:
            vertical = getattr(self, f"atrous_block{rate}_1")(x)
            branches.append(getattr(self, f"atrous_block{rate}_2")(vertical))
        return self.conv_1x1_output(torch.cat(branches, dim=1))
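经实测，更换后的计算量减小。原因：对每个空洞分支，卷积权重数（即每个输出像素的乘加次数，忽略偏置）由 3×3 卷积的 9·C_in·C_out，变为 3×1 卷积的 3·C_in·C_out 加上 1×3 卷积的 3·C_out·C_out（1×3 卷积的输入是 3×1 卷积输出的 C_out 个通道）。因此只有当 C_out < 2·C_in 时计算量才会减小，例如 C_in = 2048、C_out = 256 时每个分支约为原来的 37.5%。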