DenseNet Network
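A PyTorch implementation of DenseNet. With the defaults growth_rate=32 and block_layers=(6, 12, 24, 16), the four dense blocks follow the DenseNet-121 layout; the classifier head here is a three-layer MLP with dropout rather than the single linear layer of the original paper.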
import torch
from torch import nn


def conv_block(in_channel, out_channel):
    # pre-activation ordering used throughout DenseNet: BN -> ReLU -> 3x3 conv
    layer = nn.Sequential(
        nn.BatchNorm2d(in_channel),
        nn.ReLU(),
        nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False)
    )
    return layer

class dense_block(nn.Module):
    def __init__(self, in_channel, growth_rate, num_layers):
        super(dense_block, self).__init__()
        block = []
        channel = in_channel
        for i in range(num_layers):
            # each layer consumes all features so far and adds growth_rate channels
            block.append(conv_block(channel, growth_rate))
            channel += growth_rate
        self.net = nn.Sequential(*block)

    def forward(self, x):
        for layer in self.net:
            out = layer(x)
            x = torch.cat((out, x), dim=1)  # concatenate along the channel dimension
        return x
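
# For example, dense_block(64, 32, 6) maps a 64-channel input to
# 64 + 6 * 32 = 256 output channels.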

def transition(in_channel, out_channel):
    trans_layer = nn.Sequential(
        nn.BatchNorm2d(in_channel),
        nn.ReLU(),
        nn.Conv2d(in_channel, out_channel, 1),  # 1x1 conv to compress channels
        nn.AvgPool2d(2, 2)  # halve the spatial resolution
    )
    return trans_layer
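
# e.g. transition(256, 128) takes an (N, 256, H, W) tensor to (N, 128, H/2, W/2)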

class densenet(nn.Module):
    def __init__(self, in_channel, num_classes, growth_rate=32,
                 block_layers=(6, 12, 24, 16), init_weights=False):
        super(densenet, self).__init__()
        # stem: 7x7 stride-2 conv plus 3x3 stride-2 max pool, giving 64 channels
        self.block1 = nn.Sequential(
            # nn.Conv2d(in_channel, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.Conv2d(in_channel, 64, 7, 2, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(3, 2, padding=1)
        )
        # the hard-coded channel counts below assume growth_rate=32 and
        # block_layers=(6, 12, 24, 16): 64 -> 256 -> 128 -> 512 -> 256 -> 1024 -> 512 -> 1024
        self.DB1 = self._make_dense_block(64, growth_rate, num=block_layers[0])
        self.TL1 = self._make_transition_layer(256)
        self.DB2 = self._make_dense_block(128, growth_rate, num=block_layers[1])
        self.TL2 = self._make_transition_layer(512)
        self.DB3 = self._make_dense_block(256, growth_rate, num=block_layers[2])
        self.TL3 = self._make_transition_layer(1024)
        self.DB4 = self._make_dense_block(512, growth_rate, num=block_layers[3])
        self.global_average = nn.Sequential(
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        # self.classifier = nn.Linear(1024, num_classes)
        self.classifier = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.4),
            nn.Linear(256, num_classes),
        )
        if init_weights:
            self._initialize_weights()
    def forward(self, x):
        x = self.block1(x)
        x = self.DB1(x)
        x = self.TL1(x)
        x = self.DB2(x)
        x = self.TL2(x)
        x = self.DB3(x)
        x = self.TL3(x)
        x = self.DB4(x)
        x = self.global_average(x)
        x = x.view(x.shape[0], -1)  # flatten to (batch, 1024)
        x = self.classifier(x)
        return x

    def _make_dense_block(self, channels, growth_rate, num):
        block = []
        block.append(dense_block(channels, growth_rate, num))
        return nn.Sequential(*block)

    def _make_transition_layer(self, channels):
        block = []
        block.append(transition(channels, channels // 2))
        return nn.Sequential(*block)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

net = densenet(in_channel=3, num_classes=10)
# net = densenet(3, 10)
# print(net)
x = torch.rand(1, 3, 224, 224)
print(net(x))
# for name, layer in net.named_children():
#     if name != 'classifier':
#         x = layer(x)
#         print(name, 'output shape:', x.shape)
#     else:
#         x = x.view(x.size(0), -1)
#         x = layer(x)
#         print(name, 'output shape', x.shape)
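
# A quick shape sanity check (a sketch assuming the default configuration
# above; the expected output is for the num_classes=10 model built here):
with torch.no_grad():
    out = net(x)
print(out.shape)  # torch.Size([1, 10])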