DenseNet Code Reproduction

A PyTorch reimplementation of DenseNet (Densely Connected Convolutional Networks, Huang et al.):

```
import torch
import torch.nn as nn

# Use the GPU when one is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

```
class DenseNet(nn.Module):
    # Architecture hyperparameters
    dense_block = 4            # number of dense blocks
    num_layer_per_block = 4    # composition layers per dense block
    p_drop = 0                 # dropout probability (0 disables dropout)
    use_bottleneck = False     # add a 1x1 bottleneck conv before each 3x3 conv
    compression_factor = 1.0   # channel compression at transition layers
    fc_feats = [2]             # output sizes of the fully connected head

    def __init__(self,growth_rate = 8,activate_fn = nn.ReLU,normalization=nn.BatchNorm2d):
        super(DenseNet,self).__init__()
        # Build the initial (stem) layer first
        self.initial = DenseNetInitialLayer(growth_rate,activate_fn,normalization)
        c_now = self.initial.c_now

        # Build the dense blocks, with a transition layer between consecutive blocks
        for i in range(self.dense_block):
            i_ = i + 1
            self.add_module("block"+str(i_),DenseNetBlock(c_now,
                                                          num_layer=self.num_layer_per_block//2 if self.use_bottleneck else self.num_layer_per_block,
                                                          growth_rate=growth_rate,
                                                          p_drop=self.p_drop,
                                                          activate_fn=activate_fn,
                                                          normalization=normalization,
                                                          use_bottleneck=self.use_bottleneck))
            c_now = list(self.children())[-1].c_now

            if i < (self.dense_block - 1):
                self.add_module("trans"+str(i_),DenseNetTransitionLayer(c_now,
                                                                        p_drop=self.p_drop,
                                                                        compression_factor=self.compression_factor,
                                                                        activate_fn=activate_fn,
                                                                        normalization=normalization))
                c_now = list(self.children())[-1].c_now

        # Fully connected head, built once after all dense blocks
        self.fcs = []
        f_now = c_now
        for f in self.fc_feats:
            fc = nn.Linear(f_now, f).to(device)
            fc.weight.data.normal_(0, 0.01)  # small-variance Gaussian init
            fc.bias.data.fill_(0)
            self.fcs.append(fc)
            f_now = f
        self.fcs = nn.ModuleList(self.fcs)

    def forward(self, x):
        # The stem returns both its final and an intermediate feature map;
        # only the final one is propagated through the dense blocks
        x, _ = self.initial(x)
        for name, module in self.named_children():
            if name == "initial":
                continue
            if name == "fcs":
                break
            x = module(x)

        # Global average pooling over the spatial dimensions (H, W)
        x = torch.mean(x, dim=(2, 3))

        for fc in self.fcs:
            x = fc(x)

        return x

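# Transition layer: a 1x1 composition conv that can compress the channel
# count, followed by 2x2 average pooling to halve the spatial resolution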
class DenseNetTransitionLayer(nn.Module):
    def __init__(self,c_in,p_drop,compression_factor,activate_fn,normalization):
        super(DenseNetTransitionLayer,self).__init__()
        c_out = int(compression_factor*c_in)
        self.composition = DenseNetComposLayer(c_in,c_out,kernel_size=1,p_drop=p_drop,activate_fn=activate_fn,normalization=normalization)
        self.pool = nn.AvgPool2d(kernel_size=2,stride=2)
        self.c_now = c_out

    def forward(self,x):
        x = self.composition(x)
        x = self.pool(x)
        return x


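# Dense block: each composition layer's output is concatenated onto its
# input, so the channel count grows by growth_rate per layer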
class DenseNetBlock(nn.Module):
    def __init__(self,c_in,num_layer,growth_rate,p_drop,activate_fn,normalization,use_bottleneck,
                 transposed = False):
        super(DenseNetBlock,self).__init__()
        c_now = c_in
        self.use_bottleneck = use_bottleneck
        for i in range(num_layer):
            i_ = i + 1
            if self.use_bottleneck:
                self.add_module("bneck%d" % i_,DenseNetComposLayer(c_now,4*growth_rate,
                                                                   kernel_size=1,p_drop=p_drop,
                                                                   activate_fn=activate_fn,
                                                                   normalization=normalization))
            self.add_module("compo%d" % i_,DenseNetComposLayer(4*growth_rate if self.use_bottleneck else c_now,
                                                               growth_rate,kernel_size=3,
                                                               activate_fn=activate_fn,
                                                               normalization=normalization,
                                                               transposed=transposed))
            c_now += list(self.children())[-1].c_now
            self.c_now = c_now

    def forward(self, x):
        x_before = x
        for name, module in self.named_children():
            # Remember the running feature map before each dense step so it
            # can be concatenated with the new features (dense connectivity)
            if ((self.use_bottleneck and name.startswith("bneck"))
                    or (not self.use_bottleneck and name.startswith("compo"))):
                x_before = x
            x = module(x)
            if name.startswith("compo"):
                x = torch.cat([x_before, x], dim=1)
        return x


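# Composition layer: pre-activation ordering (norm -> activation -> conv),
# optionally followed by spatial dropout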
class DenseNetComposLayer(nn.Module):
    def __init__(self,c_in,c_out,kernel_size,p_drop,activate_fn,normalization,transposed=False):
        super(DenseNetComposLayer,self).__init__()
        self.p_drop = p_drop
        self.norm = normalization(c_in,track_running_stats=False).to(device)
        self.act = activate_fn(inplace=True)
        if transposed:
            assert kernel_size > 1
            self.conv = nn.ConvTranspose2d(c_in,c_out,kernel_size=kernel_size,padding=1 if kernel_size>1 else 0,
                                           stride=1,bias=False).to(device)
        else:
            self.conv = nn.Conv2d(c_in,c_out,kernel_size=kernel_size,stride=1,
                                  padding=1 if kernel_size>1 else 0,bias=False).to(device)

        nn.init.kaiming_normal_(self.conv.weight.data)
        self.drop = nn.Dropout2d(p_drop)
        self.c_now = c_out

    def forward(self,x):
        x = self.norm(x)
        x = self.act(x)
        x = self.conv(x)
        if self.p_drop:
            x = self.drop(x)
        return x



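# Initial (stem) layer: two stride-2 3x3 convs that downsample the input 4x
# and widen the channels to 4*growth_rate; it returns both the stem output
# and the intermediate feature map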
class DenseNetInitialLayer(nn.Module):
    def __init__(self,growth_rate=8,activate_fn=nn.ReLU,normalization=nn.BatchNorm2d):
        super(DenseNetInitialLayer,self).__init__()
        c_now = 2*growth_rate
        self.conv1 = nn.Conv2d(3,c_now,kernel_size=3,padding=1,stride=2,bias=False)
        nn.init.kaiming_normal_(self.conv1.weight.data)
        self.act = activate_fn(inplace=True)
        self.norm = normalization(c_now,track_running_stats=False).to(device)
        c_out = 4*growth_rate
        self.c_now = c_out
        self.conv2 = nn.Conv2d(c_now,c_out,kernel_size=3,padding=1,stride=2,bias=False)
        nn.init.kaiming_normal_(self.conv2.weight.data)
        self.c_list = [c_now,c_out]

    def forward(self,x):
        x = self.conv1(x)
        x = self.norm(x)
        x = self.act(x)
        pred_x = x
        x = self.conv2(x)
        return x,pred_x
```
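
As a quick smoke test, here is a minimal usage sketch (assuming 64x64 RGB inputs and the default two-class head; the batch size and input resolution are arbitrary choices, not from the original post):

```
# Hypothetical usage: instantiate the model and run a random batch through it
model = DenseNet(growth_rate=8).to(device)
x = torch.randn(4, 3, 64, 64, device=device)
logits = model(x)
# The stem downsamples 64x64 to 16x16 and the three transition layers halve
# it down to 2x2 before global pooling, so with fc_feats = [2] this prints
# torch.Size([4, 2])
print(logits.shape)
```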
