ResNet极大地改变了有关如何参数化深度网络功能的观点,2017年提出的DenseNet在某种程度上是这一概念的逻辑扩展.
公式1
resnet的方式
densennet和resnet的不同地方
renent是相加,densenet是并列.
用代码实现以下densenet
from d2l import mxnet as d2l
from mxnet import np, npx
from mxnet.gluon import nn
npx.set_np()
def conv_block(num_channels):
blk = nn.Sequential()
blk.add(nn.BatchNorm(),
nn.Activation('relu'),
nn.Conv2D(num_channels, kernel_size=3, padding=1))
return blk
class DenseBlock(nn.Block):
def __init__(self, num_convs, num_channels, **kwargs):
super().__init__(**kwargs)
self.net = nn.Sequential()
for _ in range(num_convs):
self.net.add(conv_block(num_channels))
def forward(self, X):
for blk in self.net:
Y = blk(X)
# Concatenate the input and output of each block on the channel
# dimension
X = np.concatenate((X, Y), axis=1)
return X
做一下测试
blk = DenseBlock(2, 10)
blk.initialize()
X = np.random.uniform(size=(4, 3, 8, 8))
Y = blk(X)
Y.shape
(4, 23, 8, 8)
densenet专门设计了减小网络通道数的结构
def transition_block(num_channels):
blk = nn.Sequential()
blk.add(nn.BatchNorm(), nn.Activation('relu'),
nn.Conv2D(num_channels, kernel_size=1),
nn.AvgPool2D(pool_size=2, strides=2))
return blk
注意上边的pool层采用AvgPool2D, 没有采用MaxPool2D是用于把重点放在整体特征的传递上,
接下来,创建一个完成的densenet
net = nn.Sequential()
net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3),
nn.BatchNorm(), nn.Activation('relu'),
nn.MaxPool2D(pool_size=3, strides=2, padding=1))
# `num_channels`: the current number of channels
num_channels, growth_rate = 64, 32
num_convs_in_dense_blocks = [4, 4, 4, 4]
for i, num_convs in enumerate(num_convs_in_dense_blocks):
net.add(DenseBlock(num_convs, growth_rate))
# This is the number of output channels in the previous dense block
num_channels += num_convs * growth_rate
# A transition layer that halves the number of channels is added between
# the dense blocks
if i != len(num_convs_in_dense_blocks) - 1:
num_channels //= 2
net.add(transition_block(num_channels))
net.add(nn.BatchNorm(),
nn.Activation('relu'),
nn.GlobalAvgPool2D(),
nn.Dense(10))
训练一下
lr, num_epochs, batch_size = 0.1, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr)
loss 0.144, train acc 0.947, test acc 0.881
5505.3 examples/sec on gpu(0)
最后的话:
这篇文章发布在CSDN/蓝色的杯子, 没事多留言,让我们一起爱智求真吧.我的邮箱wisdomfriend@126.com.