2021-05-11

对稠密连接网络(DenseNet)通道数的理解

Huang, G., Liu, Z., Weinberger, K. Q., & van der Maaten, L. (2017). Densely connected convolutional networks. In Proceedings of the IEEE conference on computer vision and pattern recognition (Vol. 1, No. 2).

https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter05_CNN/5.12_densenet

1.来源

DenseNet是ResNet的一个变种,在跨层连接上,不同于ResNet中将输入与输出相加,DenseNet在通道维上连结输入与输出。(如下图所示)

2.代码以及理解

2.1.卷积块(conv_block)

卷积块可以理解为稠密块的基本组成,卷积块=BN+ReLU+卷积。

import time
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def conv_block(in_channels, out_channels):
    blk = nn.Sequential(nn.BatchNorm2d(in_channels), 
                        nn.ReLU(),
                        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
    return blk

到这一步很好理解,正常的输入输出通道。

2.2.稠密块(DenseBlock)

这一步有点复杂了。稠密块由多个conv_block组成,每块使用相同的输出通道数。但在前向计算时,我们将每块的输入和输出在通道维上连结。不同的输入,相同的输出,可是卷积的要求就是下一层的输入等于上一层的输出啊,怎么可能输入相同输出不同?看一下代码。

class DenseBlock(nn.Module):
    def __init__(self, num_convs, in_channels, out_channels):
        super(DenseBlock, self).__init__()
        net = []
        for i in range(num_convs):
            in_c = in_channels + i * out_channels
            net.append(conv_block(in_c, out_channels))
        self.net = nn.ModuleList(net)
        self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数

    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            X = torch.cat((X, Y), dim=1)  # 在通道维上将输入和输出连结
        return X

这里可以这样理解,每一个卷积块不管输入多少通道,输出通道永远固定。而每一个输入通道数都是对之前输出通道数的累加。

我们记输入为X,输出为Y,卷积操作为f:

则第一个输出为:

Y1 = f(X)

第二个输出为:

Y2 = f(X, Y1)

第三个输出为:

Y3 = f(X, Y1, Y2)

这里面,Y1,Y2,Y3的通道数是固定的,但是输入实在第一个X的基础上每次都拼接上一次的输出,假设输入通道为3,输出通道为6,就是在每一次输入通道数都是在上一次输入通道数的基础上加上上一次的输出通道数。

但是请注意,前向传播最后返回的是X,而不是Y,因此,最后一个卷积块的输出通道为最终累加的那个值,具体公式为:

final_out_channels = in_channels + num_conv * out_channels

在下面的例子中,我们定义一个有2个输出通道数为10的卷积块。使用通道数为3的输入时,我们会得到通道数为3+2×10=23的输出。卷积块的通道数控制了输出通道数相对于输入通道数的增长,因此也被称为增长率(growth rate)。

blk = DenseBlock(2, 3, 10)
X = torch.rand(4, 3, 8, 8)
Y = blk(X)
Y.shape # torch.Size([4, 23, 8, 8])
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值