对稠密连接网络(DenseNet)通道数的理解
Huang, G., Liu, Z., Weinberger, K. Q., & van der Maaten, L. (2017). Densely connected convolutional networks. In Proceedings of the IEEE conference on computer vision and pattern recognition (Vol. 1, No. 2).
https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter05_CNN/5.12_densenet
1.来源
DenseNet是ResNet的一个变种,在跨层连接上,不同于ResNet中将输入与输出相加,DenseNet在通道维上连结输入与输出。(如下图所示)
![](https://i.loli.net/2021/05/11/Z8Hzvasd7ql2P5U.png)
2.代码以及理解
2.1.卷积块(conv_block)
卷积块可以理解为稠密块的基本组成,卷积块=BN+ReLU+卷积。
import time
import torch
from torch import nn, optim
import torch.nn.functional as F
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def conv_block(in_channels, out_channels):
blk = nn.Sequential(nn.BatchNorm2d(in_channels),
nn.ReLU(),
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
return blk
到这一步很好理解,正常的输入输出通道。
2.2.稠密块(DenseBlock)
这一步有点复杂了。稠密块由多个conv_block
组成,每块使用相同的输出通道数。但在前向计算时,我们将每块的输入和输出在通道维上连结。不同的输入,相同的输出,可是卷积的要求就是下一层的输入等于上一层的输出啊,怎么可能输入相同输出不同?看一下代码。
class DenseBlock(nn.Module):
def __init__(self, num_convs, in_channels, out_channels):
super(DenseBlock, self).__init__()
net = []
for i in range(num_convs):
in_c = in_channels + i * out_channels
net.append(conv_block(in_c, out_channels))
self.net = nn.ModuleList(net)
self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数
def forward(self, X):
for blk in self.net:
Y = blk(X)
X = torch.cat((X, Y), dim=1) # 在通道维上将输入和输出连结
return X
这里可以这样理解,每一个卷积块不管输入多少通道,输出通道永远固定。而每一个输入通道数都是对之前输出通道数的累加。
我们记输入为X,输出为Y,卷积操作为f:
则第一个输出为:
Y1 = f(X)
第二个输出为:
Y2 = f(X, Y1)
第三个输出为:
Y3 = f(X, Y1, Y2)
…
这里面,Y1,Y2,Y3的通道数是固定的,但是输入实在第一个X的基础上每次都拼接上一次的输出,假设输入通道为3,输出通道为6,就是在每一次输入通道数都是在上一次输入通道数的基础上加上上一次的输出通道数。
但是请注意,前向传播最后返回的是X,而不是Y,因此,最后一个卷积块的输出通道为最终累加的那个值,具体公式为:
final_out_channels = in_channels + num_conv * out_channels
在下面的例子中,我们定义一个有2个输出通道数为10的卷积块。使用通道数为3的输入时,我们会得到通道数为3+2×10=23的输出。卷积块的通道数控制了输出通道数相对于输入通道数的增长,因此也被称为增长率(growth rate)。
blk = DenseBlock(2, 3, 10)
X = torch.rand(4, 3, 8, 8)
Y = blk(X)
Y.shape # torch.Size([4, 23, 8, 8])