DenseNet Pytorch实现

DenseNet网络实现

DenseNet和ResNet不同在于ResNet是跨层求和,而DenseNet是跨层将特征在通道维度进行拼接,下图一是ResNet,二是DenseNet。
在这里插入图片描述
因为实在通道维度进行特征的拼接,所以底层的输出会保留进入后面的曾,这样能更好的保证梯度的传播,同时能够使用低维的特征和高维的特征进行联合训练,能够得到更好的结果。

DenseNet主要有Dense block组成,使用pytorch实现如下

def conv_block(in_channel,	out_channel):
				layer	=	nn.Sequential(
								nn.BatchNorm2d(in_channel),
								nn.ReLU(True),
								nn.Conv2d(in_channel,	out_channel,	3,	padding=1,	bias=False)
				)
				return layer


class dense_block(nn.Module):
    def __init__(self, in_channel, growth_rate, num_layers):
        super(dense_block, self).__init__()
        block = []
        channel = in_channel
        for i in range(num_layers):
            block.append(conv_block(channel, growth_rate))
            channel += growth_rate

        self.net = nn.Sequential(*block)

    def forward(self, x):
        for layer in self.net:
            out = layer(x)
            x = torch.cat((out, x), dim=1)
        return x

定义DenseNet:

class densenet(nn.Module):
    def __init__(self, in_channel, num_classes, growth_rate=32, block_layers=[6, 12,
                                                                              24, 16]):
        super(densenet, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channel, 64, 7, 2, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(3, 2, padding=1)
        )

        channels = 64
        block = []
        for i, layers in enumerate(block_layers):
            block.append(dense_block(channels, growth_rate, layers))
            channels += layers * growth_rate
            if i != len(block_layers) - 1:
                block.append(transition(channels, channels // 2))  # ᭗ᬦ	transition	੶


    channels = channels // 2
    self.block2 = nn.Sequential(*block)
    self.block2.add_module('bn', nn.BatchNorm2d(channels))
    self.block2.add_module('relu', nn.ReLU(True))
    self.block2.add_module('avg_pool', nn.AvgPool2d(3))
    self.classifier = nn.Linear(channels, num_classes)
    
def forward(self, x):
    x = self.block1(x)
    x = self.block2(x)

    x = x.view(x.shape[0], -1)
    x = self.classifier(x)
    return x
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值