# PyTorch code: a DenseNet implementation with a DenseNet-121 factory.
import torch
import torch.nn as nn
import torch.nn.functional as F
def Conv1(inchannels, outchannels, stride=2):
    """Build the network stem: 7x7 conv -> BatchNorm -> ReLU -> 3x3 max-pool.

    The convolution (padding 3) and the max-pool (padding 1) share ``stride``.
    With the default stride=2, each of them halves the spatial size
    (rounding up).

    Args:
        inchannels: number of input channels.
        outchannels: number of channels produced by the convolution.
        stride: stride used by both the convolution and the max-pool.

    Returns:
        An ``nn.Sequential`` holding the four stem layers at indices 0-3.
    """
    conv = nn.Conv2d(inchannels, outchannels, kernel_size=7, stride=stride,
                     padding=3, bias=False)  # no bias: BatchNorm follows
    norm = nn.BatchNorm2d(outchannels)
    act = nn.ReLU(inplace=True)
    pool = nn.MaxPool2d(kernel_size=3, stride=stride, padding=1)
    return nn.Sequential(conv, norm, act, pool)
class Bottleneck(nn.Module):
    """Pre-activation DenseNet bottleneck layer that appends new feature maps.

    Computes BN -> ReLU -> 1x1 conv (to 4*growthRate channels) -> BN -> ReLU ->
    3x3 conv (to growthRate channels, padding 1). The result is concatenated
    onto the input along dim 1. The output therefore has
    nChannels + growthRate channels and the same spatial size as the input.
    """

    def __init__(self, nChannels, growthRate):
        super(Bottleneck, self).__init__()
        bottleneck_width = 4 * growthRate  # width of the intermediate 1x1 projection
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, bottleneck_width, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(bottleneck_width)
        self.conv2 = nn.Conv2d(bottleneck_width, growthRate, kernel_size=3,
                               padding=1, bias=False)

    def forward(self, x):
        squeezed = self.conv1(F.relu(self.bn1(x)))
        new_features = self.conv2(F.relu(self.bn2(squeezed)))
        # Dense connectivity: keep the input untouched and append the new maps.
        return torch.cat([x, new_features], dim=1)
class Denseblock(nn.Module):
    """A stack of ``nDenseBlocks`` Bottleneck layers with dense connectivity.

    Each Bottleneck concatenates ``growthRate`` new channels onto its input,
    so layer i receives nChannels + i * growthRate channels. The whole block
    outputs nChannels + nDenseBlocks * growthRate channels.
    """

    def __init__(self, nChannels, growthRate, nDenseBlocks):
        super(Denseblock, self).__init__()
        self.denseblock = nn.Sequential(*[
            Bottleneck(nChannels + depth * growthRate, growthRate)
            for depth in range(nDenseBlocks)
        ])

    def forward(self, x):
        return self.denseblock(x)
class transition(nn.Module):
    """DenseNet transition layer: BN -> ReLU -> 1x1 conv -> 2x2 average pool.

    Maps nChannels to outChannels. It also halves the spatial resolution
    (2x2 window, stride 2, rounding down).
    """

    def __init__(self, nChannels, outChannels):
        super(transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, outChannels, kernel_size=1, bias=False)

    def forward(self, x):
        projected = self.conv1(F.relu(self.bn1(x)))
        return F.avg_pool2d(projected, kernel_size=2, stride=2)
class DenseNet(nn.Module):
    """DenseNet image classifier: stem, four dense blocks, three transitions, linear head.

    Channel bookkeeping:
    - The stem emits 64 channels.
    - Dense block i adds ``growthRate * blocks[i]`` channels.
    - Each transition compresses the count to ``int(channels * beta)`` and
      halves the spatial size.

    Args:
        blocks: sequence of four ints, the number of Bottleneck layers in each
            dense block (e.g. [6, 12, 24, 16] for DenseNet-121).
        num_classes: number of output logits.
        growthRate: channels added by every Bottleneck layer, used for all
            four dense blocks.
        beta: compression factor applied by each transition layer.
    """

    def __init__(self, blocks, num_classes=3, growthRate=32, beta=0.5):
        super(DenseNet, self).__init__()
        self.conv1 = Conv1(3, 64)  # stem expects 3-channel input

        channels = 64
        self.layer1 = Denseblock(channels, growthRate, blocks[0])
        channels += growthRate * blocks[0]
        compressed = int(channels * beta)
        self.trans1 = transition(channels, compressed)
        channels = compressed

        self.layer2 = Denseblock(channels, growthRate, blocks[1])
        channels += growthRate * blocks[1]
        compressed = int(channels * beta)
        self.trans2 = transition(channels, compressed)
        channels = compressed

        self.layer3 = Denseblock(channels, growthRate, blocks[2])
        channels += growthRate * blocks[2]
        compressed = int(channels * beta)
        self.trans3 = transition(channels, compressed)
        channels = compressed

        self.layer4 = Denseblock(channels, growthRate, blocks[3])
        channels += growthRate * blocks[3]

        # Global average pooling to 1x1 works for any input resolution. For a
        # 224x224 input the final map is 7x7, so this equals AvgPool2d(7).
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Classifier width follows the actual final channel count. This is
        # 1024 for DenseNet-121 with the default growthRate/beta.
        self.fc = nn.Linear(channels, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.trans1(self.layer1(x))
        x = self.trans2(self.layer2(x))
        x = self.trans3(self.layer3(x))
        x = self.layer4(x)
        # NOTE(review): torchvision's DenseNet applies a final BatchNorm + ReLU
        # before pooling; this model does not. Confirm that is intended.
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
def DenseNet121(**kwargs):
    """Build a DenseNet-121: dense blocks of 6, 12, 24 and 16 Bottleneck layers.

    Keyword arguments (e.g. ``num_classes``, ``growthRate``, ``beta``) are
    forwarded to ``DenseNet``. Called with no arguments, it behaves exactly as
    before.
    """
    return DenseNet([6, 12, 24, 16], **kwargs)
if __name__ == '__main__':
    # Smoke test: build DenseNet-121, print it, and check the logits shape for
    # a single random 3x224x224 image.
    model = DenseNet121()
    print(model)
    model.eval()  # use BatchNorm running statistics; don't update them here
    with torch.no_grad():  # inference only, no autograd graph needed
        dummy = torch.randn(1, 3, 224, 224)
        logits = model(dummy)
    print(logits.shape)