Implementing DenseNet in PyTorch

    import re
    from collections import OrderedDict

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils import model_zoo

    # Pretrained weight URL from torchvision's model zoo
    model_urls = {
        'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    }

    class _DenseLayer(nn.Sequential):
        """Internal structure of a DenseBlock: BN + ReLU + 1x1 Conv + BN + ReLU + 3x3 Conv,
        with an optional dropout layer at the end for regularization during training."""
        def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
            super(_DenseLayer, self).__init__()
            self.add_module("norm1", nn.BatchNorm2d(num_input_features))
            self.add_module("relu1", nn.ReLU(inplace=True))
            self.add_module("conv1", nn.Conv2d(num_input_features, bn_size*growth_rate,
                                               kernel_size=1, stride=1, bias=False))
            self.add_module("norm2", nn.BatchNorm2d(bn_size*growth_rate))
            self.add_module("relu2", nn.ReLU(inplace=True))
            self.add_module("conv2", nn.Conv2d(bn_size*growth_rate, growth_rate,
                                               kernel_size=3, stride=1, padding=1, bias=False))
            self.drop_rate = drop_rate
    
        def forward(self, x):
            new_features = super(_DenseLayer, self).forward(x)
            if self.drop_rate > 0:
                new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
            return torch.cat([x, new_features], 1)
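
A quick shape check makes the dense connectivity concrete. This is a minimal sketch (it assumes the imports above); each layer contributes growth_rate new channels on top of its input:

    layer = _DenseLayer(num_input_features=64, growth_rate=32, bn_size=4, drop_rate=0)
    x = torch.randn(1, 64, 56, 56)
    print(layer(x).shape)  # torch.Size([1, 96, 56, 56]): 64 input channels + 32 new ones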

# Implement the DenseBlock module: its layers are densely connected, so the number of input features grows linearly with depth
    class _DenseBlock(nn.Sequential):
        """DenseBlock"""
        def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
            super(_DenseBlock, self).__init__()
            for i in range(num_layers):
                layer = _DenseLayer(num_input_features+i*growth_rate, growth_rate, bn_size,
                                    drop_rate)
                self.add_module("denselayer%d" % (i+1,), layer)

# Implement the Transition layer, which is mainly a 1x1 convolution followed by a pooling layer
    class _Transition(nn.Sequential):
        """Transition layer between two adjacent DenseBlock"""
        def __init__(self, num_input_feature, num_output_features):
            super(_Transition, self).__init__()
            self.add_module("norm", nn.BatchNorm2d(num_input_feature))
            self.add_module("relu", nn.ReLU(inplace=True))
            self.add_module("conv", nn.Conv2d(num_input_feature, num_output_features,
                                              kernel_size=1, stride=1, bias=False))
            self.add_module("pool", nn.AvgPool2d(2, stride=2))

# Implement the full DenseNet network
    class DenseNet(nn.Module):
        "DenseNet-BC model"
        def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
                     bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=1000):
            """
            :param growth_rate: (int) number of filters each DenseLayer adds, `k` in the paper
            :param block_config: (tuple of 4 ints) number of layers in each DenseBlock
            :param num_init_features: (int) number of filters in the first Conv2d
            :param bn_size: (int) multiplicative factor for the bottleneck layer width
            :param compression_rate: (float) compression rate used in the Transition layers
            :param drop_rate: (float) dropout rate after each DenseLayer
            :param num_classes: (int) number of classes for classification
            """
            super(DenseNet, self).__init__()
            # first Conv2d
            self.features = nn.Sequential(OrderedDict([
                ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                ("norm0", nn.BatchNorm2d(num_init_features)),
                ("relu0", nn.ReLU(inplace=True)),
                ("pool0", nn.MaxPool2d(3, stride=2, padding=1))
            ]))
    
            # DenseBlock
            num_features = num_init_features
            for i, num_layers in enumerate(block_config):
                block = _DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)
                self.features.add_module("denseblock%d" % (i + 1), block)
                num_features += num_layers*growth_rate
                if i != len(block_config) - 1:
                    transition = _Transition(num_features, int(num_features*compression_rate))
                    self.features.add_module("transition%d" % (i + 1), transition)
                    num_features = int(num_features * compression_rate)
    
            # final bn+ReLU
            self.features.add_module("norm5", nn.BatchNorm2d(num_features))
            self.features.add_module("relu5", nn.ReLU(inplace=True))
    
            # classification layer
            self.classifier = nn.Linear(num_features, num_classes)
    
            # params initialization
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.bias, 0)
                    nn.init.constant_(m.weight, 1)
                elif isinstance(m, nn.Linear):
                    nn.init.constant_(m.bias, 0)
    
        def forward(self, x):
            features = self.features(x)
            # global average pooling; the 7x7 kernel assumes 224x224 inputs, which give 7x7 final feature maps
            out = F.avg_pool2d(features, 7, stride=1).view(features.size(0), -1)
            out = self.classifier(out)
            return out
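
A forward-pass sanity check: with the default configuration and a 224x224 input (the size the 7x7 pooling kernel assumes), the network produces 1000 logits:

    net = DenseNet().eval()
    with torch.no_grad():
        logits = net(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])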

# By choosing different network parameters we obtain DenseNets of different depths. Here we implement DenseNet-121, for which PyTorch also provides pretrained weights
    def densenet121(pretrained=False, **kwargs):
        """DenseNet121"""
        model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
                         **kwargs)
    
        if pretrained:
            # '.'s are no longer allowed in module names, but the previous _DenseLayer
            # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
            # They are also in the checkpoints in model_urls. This pattern is used
            # to find such keys.
            pattern = re.compile(
                r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
            state_dict = model_zoo.load_url(model_urls['densenet121'])
            for key in list(state_dict.keys()):
                res = pattern.match(key)
                if res:
                    new_key = res.group(1) + res.group(2)
                    state_dict[new_key] = state_dict[key]
                    del state_dict[key]
            model.load_state_dict(state_dict)
        return model
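
Deeper variants only change the constructor arguments. For instance, a DenseNet-169 builder could look like the sketch below (the block_config follows the paper's configuration; pretrained-weight loading, which would reuse the key remapping above, is omitted):

    def densenet169(**kwargs):
        """DenseNet169; loading pretrained weights would reuse the key remapping shown above"""
        return DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
                        **kwargs)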
