import torch
import torch.nn as nn
# Device on which the layers below place their parameters at construction
# time: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
class DenseNet(nn.Module):
    """DenseNet classifier assembled from the layer classes in this file.

    Layout: ``DenseNetInitialLayer`` -> ``dense_block`` x ``DenseNetBlock``
    (with a ``DenseNetTransitionLayer`` between consecutive blocks) ->
    global average pooling over H and W -> the Linear layers in ``fc_feats``.

    Architecture hyper-parameters are class attributes, so variants are
    configured by subclassing; ``growth_rate``, the activation class and the
    normalization class are constructor arguments.

    Args:
        growth_rate: channels added by every composition layer in a block.
        activate_fn: activation class, instantiated by the sub-layers.
        normalization: normalization class, instantiated by the sub-layers.
    """

    dense_block = 4  # number of dense blocks
    # Composition layers per block. With use_bottleneck each layer is a
    # 1x1 bottleneck + 3x3 pair, counted as two, so the block gets half.
    num_layer_per_block = 4
    droup_out = 0  # dropout probability (attribute name kept for compatibility)
    use_bottleneck = False
    compression_factor = 1.0  # channel multiplier applied by transition layers
    fc_feats = [2]  # output sizes of the stacked Linear layers; last one is the output size

    def __init__(self, growth_rate=8, activate_fn=nn.ReLU, normalization=nn.BatchNorm2d):
        super(DenseNet, self).__init__()
        # Stem layer; its c_now is the channel count entering the first block.
        self.initial = DenseNetInitialLayer(growth_rate, activate_fn, normalization)
        c_now = self.initial.c_now

        # Integer division: DenseNetBlock passes this to range().
        if self.use_bottleneck:
            layers_per_block = self.num_layer_per_block // 2
        else:
            layers_per_block = self.num_layer_per_block

        # Names of the block/transition modules, in the order forward applies them.
        # Module names ("block1", "trans1", ...) are kept so state_dict keys are stable.
        self._feature_names = []
        for i in range(self.dense_block):
            idx = i + 1
            block = DenseNetBlock(
                c_now,
                num_layer=layers_per_block,
                growth_rate=growth_rate,
                p_drop=self.droup_out,
                activate_fn=activate_fn,
                normalization=normalization,
                use_bottleneck=self.use_bottleneck,
            )
            self.add_module("block" + str(idx), block)
            self._feature_names.append("block" + str(idx))
            c_now = block.c_now
            # No transition after the last block.
            if i < self.dense_block - 1:
                trans = DenseNetTransitionLayer(
                    c_now,
                    p_drop=self.droup_out,
                    compression_factor=self.compression_factor,
                    activate_fn=activate_fn,
                    normalization=normalization,
                )
                self.add_module("trans" + str(idx), trans)
                self._feature_names.append("trans" + str(idx))
                c_now = trans.c_now

        # Classifier head: Linear layers chained on the pooled channel vector.
        # NOTE(review): no activation is inserted between stacked Linear layers.
        fcs = []
        f_now = c_now
        for f in self.fc_feats:
            fc = nn.Linear(f_now, f).to(device)
            # nn.init runs under no_grad, so these in-place inits are legal on leaf params.
            nn.init.normal_(fc.weight, 0.0, 0.01)
            nn.init.zeros_(fc.bias)
            fcs.append(fc)
            f_now = f
        self.fcs = nn.ModuleList(fcs)

    def forward(self, x):
        """Run the network.

        Args:
            x: image batch fed to ``DenseNetInitialLayer`` (its first conv expects 3 channels).

        Returns:
            Output of the last Linear layer, shape ``(N, fc_feats[-1])``;
            no softmax is applied.
        """
        # The stem returns (features, pre-conv2 activation); only features continue.
        x, _ = self.initial(x)
        for name in self._feature_names:
            x = getattr(self, name)(x)
        # Global average pooling over the spatial dims (H, W), keeping channels.
        x = x.mean(dim=(2, 3))
        for fc in self.fcs:
            x = fc(x)
        return x
class DenseNetTransitionLayer(nn.Module):
    """Transition placed between two dense blocks.

    Applies a 1x1 ``DenseNetComposLayer`` that rescales the channel count by
    ``compression_factor``, then 2x2 average pooling with stride 2.

    Args:
        c_in: number of input channels.
        p_drop: dropout probability forwarded to the composition layer.
        compression_factor: output channels are ``int(compression_factor * c_in)``.
        activate_fn: activation class forwarded to the composition layer.
        normalization: normalization class forwarded to the composition layer.

    Attributes:
        c_now: number of output channels.
    """

    def __init__(self, c_in, p_drop, compression_factor, activate_fn, normalization):
        super().__init__()
        out_channels = int(compression_factor * c_in)
        # Registration order (composition, then pool) is kept for stable state_dict ordering.
        self.composition = DenseNetComposLayer(
            c_in,
            out_channels,
            kernel_size=1,
            p_drop=p_drop,
            activate_fn=activate_fn,
            normalization=normalization,
        )
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.c_now = out_channels

    def forward(self, x):
        """Compress channels with the 1x1 layer, then average-pool 2x2 / stride 2."""
        return self.pool(self.composition(x))
class DenseNetBlock(nn.Module):
    """Dense block: every layer sees the concatenation of all previous outputs.

    Each of the ``num_layer`` layers is an optional 1x1 bottleneck
    (``c -> 4 * growth_rate``) followed by a 3x3 composition layer producing
    ``growth_rate`` channels. The layer's output is concatenated onto its
    input along dim 1.

    Args:
        c_in: number of input channels.
        num_layer: number of layers; converted with ``int`` so a float count is accepted.
        growth_rate: channels produced by each 3x3 composition layer.
        p_drop: dropout probability for every composition layer.
        activate_fn: activation class forwarded to the composition layers.
        normalization: normalization class forwarded to the composition layers.
        use_bottleneck: insert a 1x1 bottleneck before each 3x3 layer.
        transposed: use transposed convolutions in the 3x3 layers.

    Attributes:
        c_now: output channels, ``c_in + num_layer * growth_rate``.
    """

    def __init__(self, c_in, num_layer, growth_rate, p_drop, activate_fn, normalization, use_bottleneck,
                 transposed=False):
        super(DenseNetBlock, self).__init__()
        self.use_bottleneck = use_bottleneck
        # range() below needs an int; callers may pass e.g. 4 / 2 == 2.0.
        self.num_layer = int(num_layer)
        c_now = c_in
        for i in range(1, self.num_layer + 1):
            if use_bottleneck:
                self.add_module("bneck%d" % i, DenseNetComposLayer(c_now, 4 * growth_rate,
                                                                   kernel_size=1, p_drop=p_drop,
                                                                   activate_fn=activate_fn,
                                                                   normalization=normalization))
            compo = DenseNetComposLayer(4 * growth_rate if use_bottleneck else c_now,
                                        growth_rate, kernel_size=3,
                                        p_drop=p_drop,
                                        activate_fn=activate_fn,
                                        normalization=normalization,
                                        transposed=transposed)
            self.add_module("compo%d" % i, compo)
            # Concatenation grows the running channel count by this layer's output.
            c_now += compo.c_now
        self.c_now = c_now

    def forward(self, x):
        """Apply each layer and concatenate its output onto its input (dim 1)."""
        for i in range(1, self.num_layer + 1):
            # Keep the layer input `x` intact; only the new features go through the convs.
            new = x
            if self.use_bottleneck:
                new = getattr(self, "bneck%d" % i)(new)
            new = getattr(self, "compo%d" % i)(new)
            x = torch.cat([x, new], dim=1)
        return x
class DenseNetComposLayer(nn.Module):
    """Pre-activation unit: normalization -> activation -> conv -> dropout.

    Args:
        c_in: input channels.
        c_out: output channels of the convolution.
        kernel_size: convolution kernel size. Padding is ``kernel_size // 2``
            with stride 1, so odd kernels keep H and W unchanged.
        p_drop: probability for ``nn.Dropout2d``; ``None`` disables dropout.
        activate_fn: activation class, instantiated with ``inplace=True``.
        normalization: normalization class, called as
            ``normalization(c_in, track_running_stats=False)``.
        transposed: use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``
            (requires ``kernel_size > 1``).

    Attributes:
        c_now: number of output channels (``c_out``).
    """

    def __init__(self, c_in, c_out, kernel_size, p_drop, activate_fn, normalization, transposed=False):
        super(DenseNetComposLayer, self).__init__()
        self.p_drop = p_drop
        # For BatchNorm, track_running_stats=False means batch statistics are used in eval mode too.
        self.norm = normalization(c_in, track_running_stats=False).to(device)
        # In-place activation overwrites the normalization output, not the layer input
        # (assumes `normalization` returns a new tensor, as BatchNorm2d does).
        self.act = activate_fn(inplace=True)
        # Same-size padding for odd kernels; equals the previous 1/0 choice for k = 1, 2, 3.
        padding = kernel_size // 2
        if transposed:
            assert kernel_size > 1
            self.conv = nn.ConvTranspose2d(c_in, c_out, kernel_size=kernel_size, padding=padding,
                                           stride=1, bias=False).to(device)
        else:
            self.conv = nn.Conv2d(c_in, c_out, kernel_size=kernel_size, stride=1,
                                  padding=padding, bias=False).to(device)
        nn.init.kaiming_normal_(self.conv.weight.data)
        # nn.Dropout2d(None) raises at construction, so None maps to a no-op module.
        self.drop = nn.Identity() if p_drop is None else nn.Dropout2d(p_drop)
        self.c_now = c_out

    def forward(self, x):
        """Normalize, activate, convolve, then apply dropout (if enabled)."""
        x = self.act(self.norm(x))
        x = self.conv(x)
        return self.drop(x)
class DenseNetInitialLayer(nn.Module):
    """Stem: conv(stride 2) -> norm -> activation -> conv(stride 2).

    Args:
        growth_rate: the first conv outputs ``2 * growth_rate`` channels and
            the second ``4 * growth_rate``.
        activate_fn: activation class, instantiated with ``inplace=True``.
        normalization: normalization class, called as
            ``normalization(channels, track_running_stats=False)``.
        in_channels: channels of the input image (default 3, as before).

    Attributes:
        c_now: output channels of the second conv (``4 * growth_rate``).
        c_list: ``[2 * growth_rate, 4 * growth_rate]``, the channels of the
            two tensors returned by ``forward``, in reverse order.
    """

    def __init__(self, growth_rate=8, activate_fn=nn.ReLU, normalization=nn.BatchNorm2d, in_channels=3):
        super(DenseNetInitialLayer, self).__init__()
        c_now = 2 * growth_rate
        # Every parameterised sub-module goes to `device`, so the stem is not
        # split across devices (previously only the norm was moved).
        self.conv1 = nn.Conv2d(in_channels, c_now, kernel_size=3, padding=1, stride=2, bias=False).to(device)
        nn.init.kaiming_normal_(self.conv1.weight.data)
        self.act = activate_fn(inplace=True)
        self.norm = normalization(c_now, track_running_stats=False).to(device)
        c_out = 4 * growth_rate
        self.c_now = c_out
        self.conv2 = nn.Conv2d(c_now, c_out, kernel_size=3, padding=1, stride=2, bias=False).to(device)
        nn.init.kaiming_normal_(self.conv2.weight.data)
        self.c_list = [c_now, c_out]

    def forward(self, x):
        """Return ``(conv2 output, activated conv1 output)``.

        Each conv has stride 2, so the first element is downsampled twice and
        the second once relative to ``x``.
        """
        x = self.conv1(x)
        x = self.norm(x)
        x = self.act(x)
        # Kept for callers that want the higher-resolution feature map.
        pred_x = x
        x = self.conv2(x)
        return x, pred_x
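# DenseNet code reproduction (DenseNet代码复现)
# Source page note: latest recommended article published 2024-07-20 17:12:48 (最新推荐文章于 2024-07-20 17:12:48 发布)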