我这里实现的是密集连接的多尺度金字塔模块
def conv_block(in_channels, out_channels, kernel_size):
blk = nn.Sequential(nn.BatchNorm2d(in_channels),
nn.ReLU(),
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=1))
return blk
稠密块由多个conv_block组成,每块使用相同的输出通道数。但在前向计算时,我们将每块的输入和输出在通道维上连结。
class DenseBlock(nn.Module):
def __init__(self, num_convs, in_channels, out_channels,
kernel_sizes=[1,3,5,7]):
super(DenseBlock, self).__init__()
net = []
for i in range(num_convs):
in_c = in_channels + i * out_channels
net.append(conv_block(in_c, out_channels, kernel_size=kernel_sizes[i]))
self.net = nn.ModuleList(net)
self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数
def forward(self, X):
for blk in self.net:
Y = blk(X)
X = torch.cat((X, Y), dim=1) # 在通道维上将输入和输出连结
return X