code
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBnReLU(nn.Module):
    """Conv2d -> BatchNorm2d -> in-place ReLU; the basic building block used by FeatureNet.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation:
            forwarded unchanged to ``nn.Conv2d``.
        bias: whether the convolution has a bias term (default True).
            NOTE(review): in training mode a per-channel bias placed right before
            BatchNorm is cancelled by the mean subtraction; kept for checkpoint compatibility.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, dilation=1):
        super().__init__()
        # Submodules are registered in the order conv, bn, relu.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalization and ReLU in sequence."""
        y = self.conv(x)
        y = self.bn(y)
        return self.relu(y)
class FeatureNet(nn.Module):
    """Multi-scale feature extractor: strided conv backbone plus a top-down fusion path.

    Level 0 runs at the input resolution (two 3x3 ConvBnReLU). Every further
    level starts with a stride-2 5x5 ConvBnReLU, which maps a size H to
    ceil(H / 2), followed by two 3x3 ConvBnReLU. ``forward`` returns one
    feature map per level, ordered finest first.

    Args:
        inner_channels: channel count of each level, finest first,
            e.g. ``[8, 16, 32, 64]``. Must contain at least one entry.
        in_channels: channel count of the input tensor (default 3).

    Raises:
        ValueError: if ``inner_channels`` is empty.
    """

    def __init__(self, inner_channels, in_channels=3):
        super().__init__()
        if len(inner_channels) == 0:
            raise ValueError("inner_channels must contain at least one level")
        top = inner_channels[-1]
        last_level = len(inner_channels) - 1
        convs, inners, outputs = [], [], []
        for level, ch in enumerate(inner_channels):
            if level == 0:
                # Full-resolution stem: two 3x3 convs.
                convs.append(ConvBnReLU(in_channels, ch, 3, 1, 1))
                convs.append(ConvBnReLU(ch, ch, 3, 1, 1))
            else:
                # Downsample by 2 with a 5x5 conv, then refine with two 3x3 convs.
                convs.append(ConvBnReLU(inner_channels[level - 1], ch, 5, 2, 2))
                convs.append(ConvBnReLU(ch, ch, 3, 1, 1))
                convs.append(ConvBnReLU(ch, ch, 3, 1, 1))
            # 1x1 projection from the fused width (top) back to this level's width.
            outputs.append(nn.Conv2d(top, ch, 1, bias=False))
            if level != last_level:
                # 1x1 lateral connection lifting this level to the fused width (top).
                inners.append(nn.Conv2d(ch, top, 1, bias=True))
        # Reverse so index 0 corresponds to the coarsest level, matching forward().
        inners.reverse()
        outputs.reverse()
        # NOTE: with two or more levels, forward() returns level 0 unfused, so
        # outputs[-1] and inners[-1] (the level-0 entries) never receive gradients.
        # They are kept so existing checkpoints keep loading with strict=True; this
        # can trip DistributedDataParallel's unused-parameter check.
        self.convs = nn.Sequential(*convs)
        # inners/outputs are indexed, never chained, so ModuleList (same state_dict keys).
        self.inners = nn.ModuleList(inners)
        self.outputs = nn.ModuleList(outputs)

    def forward(self, x):
        """Return per-level feature maps, finest first (conv1, f1, f2, ...).

        The coarsest level is projected by ``outputs[0]``. Each intermediate
        level adds its lateral projection to the upsampled running sum and is
        projected back to its own width. Level 0 is returned as the raw
        backbone feature, unless it is the only level.
        """
        # Keep the last conv output of each level. Level 0 has 2 convs
        # (indices 0-1); level k >= 1 has 3 (indices 3k-1 .. 3k+1). So each
        # level ends at an index i with i % 3 == 1.
        feats = []
        for i, layer in enumerate(self.convs):
            x = layer(x)
            if i % 3 == 1:
                feats.append(x)
        feats.reverse()  # coarsest first

        outx = []
        inner = feats[0]
        last = len(feats) - 1
        for level, feat in enumerate(feats):
            if level == 0:
                outx.append(self.outputs[level](feat))
            elif level != last:
                # Upsample to this level's exact size, not by a fixed 2x. The
                # stride-2 convs round odd sizes up, so 2x can overshoot by one
                # pixel and break the addition.
                # For even sizes, size=2*H matches scale_factor=2.0 exactly.
                up = F.interpolate(inner, size=feat.shape[-2:], mode="bilinear", align_corners=False)
                inner = up + self.inners[level - 1](feat)
                outx.append(self.outputs[level](inner))
            else:
                outx.append(feat)
        outx.reverse()
        return outx  # conv1, f1, f2, f3
if __name__ == "__main__":
    # Smoke test: run a 4-level and a 5-level pyramid on one random 128x160
    # input and print the shape of every returned feature map.
    image = torch.randint(1, 255, (1, 3, 128, 160)).float()
    configs = ([8, 16, 32, 64], [8, 16, 32, 64, 128])
    nets = [FeatureNet(inner_channels=chs) for chs in configs]
    # print(nets[0])  # uncomment to check that every layer is registered
    results = [net(image) for net in nets]
    for feats in results:
        for feat in feats:
            print(feat.shape)
result
torch.Size([1, 8, 128, 160])
torch.Size([1, 16, 64, 80])
torch.Size([1, 32, 32, 40])
torch.Size([1, 64, 16, 20])
torch.Size([1, 8, 128, 160])
torch.Size([1, 16, 64, 80])
torch.Size([1, 32, 32, 40])
torch.Size([1, 64, 16, 20])
torch.Size([1, 128, 8, 10])
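注:
网络构建好后,先用 print(net) 检查 __init__() 中定义的层是否都能被打印出来。只有注册为子模块的层(直接赋值为 nn.Module 属性,或放入 nn.Sequential / nn.ModuleList)才会被打印;放在普通 Python list 中的层不会被注册,因此不会随 net.cuda() / net.to(device) 移到 GPU 上,其参数也不会出现在 net.parameters() 中(优化器无法更新它们)。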