Environment:
torch 1.2.0
torchsummary 1.5.1
torchvision 0.4.0
opencv 4.4
Network structure: a UNet built from Inception-style blocks; not much to explain here. Each block replaces the plain 3x3 convolution with a 1x1 -> depthwise 3x3 -> 1x1 bottleneck, and downsampling uses stride-2 depthwise convolutions instead of max-pooling.
import torch
import torch.nn as nn


class Inception(nn.Module):
    """1x1 -> depthwise 3x3 -> 1x1 bottleneck, used in place of a plain 3x3 conv."""
    def __init__(self, in_ch, out_ch):
        super(Inception, self).__init__()
        hide_ch = out_ch // 2
        self.inception = nn.Sequential(
            nn.Conv2d(in_ch, hide_ch, 1),                               # pointwise: reduce channels
            nn.BatchNorm2d(hide_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(hide_ch, hide_ch, 3, padding=1, groups=hide_ch),  # depthwise 3x3
            nn.BatchNorm2d(hide_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(hide_ch, out_ch, 1)                               # pointwise: expand channels
        )

    def forward(self, x):
        return self.inception(x)
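# Optional sanity check (a rough sketch): the bottleneck above uses far fewer parameters
# than a plain 3x3 convolution for the same channel mapping, e.g. for 64 -> 128 channels:
#   nn.Conv2d(64, 128, 3, padding=1)  -> 64*128*9 + 128 = 73,856 parameters
#   Inception(64, 128)                -> roughly 13k parameters (see the summary below)
# plain = nn.Conv2d(64, 128, 3, padding=1)
# block = Inception(64, 128)
# print(sum(p.numel() for p in plain.parameters()),
#       sum(p.numel() for p in block.parameters()))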
class DoubleConv(nn.Module):
    """Two Inception blocks back to back, mirroring the double 3x3 conv of the original UNet."""
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.doubleConv = nn.Sequential(
            Inception(in_ch, out_ch),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            Inception(out_ch, out_ch),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.doubleConv(x)
class UNet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(UNet, self).__init__()
        # down: stride-2 depthwise convolutions are used instead of max-pooling
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.Conv2d(64, 64, 2, 2, groups=64)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.Conv2d(128, 128, 2, 2, groups=128)
        self.bottom = DoubleConv(128, 256)
        # up
        self.up3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.conv3 = DoubleConv(128 * 2, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.conv4 = DoubleConv(64 * 2, 64)
        self.out = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        # down
        conv1 = self.conv1(x)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        bottom = self.bottom(pool2)
        # up: upsample, concatenate with the matching encoder feature map, then fuse
        up3 = self.up3(bottom)
        merge3 = torch.cat([up3, conv2], dim=1)
        conv3 = self.conv3(merge3)
        up4 = self.up4(conv3)
        merge4 = torch.cat([up4, conv1], dim=1)
        conv4 = self.conv4(merge4)
        out = self.out(conv4)
        return nn.Sigmoid()(out)
if __name__ == '__main__':
    net = UNet(1, 2)
    inputs = torch.zeros((1, 1, 8, 8), dtype=torch.float32)
    output = net(inputs)
    print(output.size())  # expected: torch.Size([1, 2, 8, 8])
    print(output)
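The layer-by-layer summary below appears to come from torchsummary with a single-channel 160x160 input (the input size is inferred from the output shapes in the table, not stated in the code above). A minimal sketch of how such a summary can be generated:

from torchsummary import summary

net = UNet(1, 2)
# assumed input shape: 1 channel, 160x160, matching the shapes in the summary below
summary(net, (1, 160, 160), device='cpu')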
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 32, 160, 160] 64
BatchNorm2d-2 [-1, 32, 160, 160] 64
ReLU-3 [-1, 32, 160, 160] 0
Conv2d-4 [-1, 32, 160, 160] 320
BatchNorm2d-5 [-1, 32, 160, 160] 64
ReLU-6 [-1, 32, 160, 160] 0
Conv2d-7 [-1, 64, 160, 160] 2,112
Inception-8 [-1, 64, 160, 160] 0
BatchNorm2d-9 [-1, 64, 160, 160] 128
ReLU-10 [-1, 64, 160, 160] 0
Conv2d-11 [-1, 32, 160, 160] 2,080
BatchNorm2d-12 [-1, 32, 160, 160] 64
ReLU-13 [-1, 32, 160, 160] 0
Conv2d-14 [-1, 32, 160, 160] 320
BatchNorm2d-15 [-1, 32, 160, 160] 64
ReLU-16 [-1, 32, 160, 160] 0
Conv2d-17 [-1, 64, 160, 160] 2,112
Inception-18 [-1, 64, 160, 160] 0
BatchNorm2d-19 [-1, 64, 160, 160] 128
ReLU-20 [-1, 64, 160, 160] 0
DoubleConv-21 [-1, 64, 160, 160] 0
Conv2d-22 [-1, 64, 80, 80] 320
Conv2d-23 [-1, 64, 80, 80] 4,160
BatchNorm2d-24 [-1, 64, 80, 80] 128
ReLU-25 [-1, 64, 80, 80] 0
Conv2d-26 [-1, 64, 80, 80] 640
BatchNorm2d-27 [-1, 64, 80, 80] 128
ReLU-28 [-1, 64, 80, 80] 0
Conv2d-29 [-1, 128, 80, 80] 8,320
Inception-30 [-1, 128, 80, 80] 0
BatchNorm2d-31 [-1, 128, 80, 80] 256
ReLU-32 [-1, 128, 80, 80] 0
Conv2d-33 [-1, 64, 80, 80] 8,256
BatchNorm2d-34 [-1, 64, 80, 80] 128
ReLU-35 [-1, 64, 80, 80] 0
Conv2d-36 [-1, 64, 80, 80] 640
BatchNorm2d-37 [-1, 64, 80, 80] 128
ReLU-38 [-1, 64, 80, 80] 0
Conv2d-39 [-1, 128, 80, 80] 8,320
Inception-40 [-1, 128, 80, 80] 0
BatchNorm2d-41 [-1, 128, 80, 80] 256
ReLU-42 [-1, 128, 80, 80] 0
DoubleConv-43 [-1, 128, 80, 80] 0
Conv2d-44 [-1, 128, 40, 40] 640
Conv2d-45 [-1, 128, 40, 40] 16,512
BatchNorm2d-46 [-1, 128, 40, 40] 256
ReLU-47 [-1, 128, 40, 40] 0
Conv2d-48 [-1, 128, 40, 40] 1,280
BatchNorm2d-49 [-1, 128, 40, 40] 256
ReLU-50 [-1, 128, 40, 40] 0
Conv2d-51 [-1, 256, 40, 40] 33,024
Inception-52 [-1, 256, 40, 40] 0
BatchNorm2d-53 [-1, 256, 40, 40] 512
ReLU-54 [-1, 256, 40, 40] 0
Conv2d-55 [-1, 128, 40, 40] 32,896
BatchNorm2d-56 [-1, 128, 40, 40] 256
ReLU-57 [-1, 128, 40, 40] 0
Conv2d-58 [-1, 128, 40, 40] 1,280
BatchNorm2d-59 [-1, 128, 40, 40] 256
ReLU-60 [-1, 128, 40, 40] 0
Conv2d-61 [-1, 256, 40, 40] 33,024
Inception-62 [-1, 256, 40, 40] 0
BatchNorm2d-63 [-1, 256, 40, 40] 512
ReLU-64 [-1, 256, 40, 40] 0
DoubleConv-65 [-1, 256, 40, 40] 0
ConvTranspose2d-66 [-1, 128, 80, 80] 131,200
Conv2d-67 [-1, 64, 80, 80] 16,448
BatchNorm2d-68 [-1, 64, 80, 80] 128
ReLU-69 [-1, 64, 80, 80] 0
Conv2d-70 [-1, 64, 80, 80] 640
BatchNorm2d-71 [-1, 64, 80, 80] 128
ReLU-72 [-1, 64, 80, 80] 0
Conv2d-73 [-1, 128, 80, 80] 8,320
Inception-74 [-1, 128, 80, 80] 0
BatchNorm2d-75 [-1, 128, 80, 80] 256
ReLU-76 [-1, 128, 80, 80] 0
Conv2d-77 [-1, 64, 80, 80] 8,256
BatchNorm2d-78 [-1, 64, 80, 80] 128
ReLU-79 [-1, 64, 80, 80] 0
Conv2d-80 [-1, 64, 80, 80] 640
BatchNorm2d-81 [-1, 64, 80, 80] 128
ReLU-82 [-1, 64, 80, 80] 0
Conv2d-83 [-1, 128, 80, 80] 8,320
Inception-84 [-1, 128, 80, 80] 0
BatchNorm2d-85 [-1, 128, 80, 80] 256
ReLU-86 [-1, 128, 80, 80] 0
DoubleConv-87 [-1, 128, 80, 80] 0
ConvTranspose2d-88 [-1, 64, 160, 160] 32,832
Conv2d-89 [-1, 32, 160, 160] 4,128
BatchNorm2d-90 [-1, 32, 160, 160] 64
ReLU-91 [-1, 32, 160, 160] 0
Conv2d-92 [-1, 32, 160, 160] 320
BatchNorm2d-93 [-1, 32, 160, 160] 64
ReLU-94 [-1, 32, 160, 160] 0
Conv2d-95 [-1, 64, 160, 160] 2,112
Inception-96 [-1, 64, 160, 160] 0
BatchNorm2d-97 [-1, 64, 160, 160] 128
ReLU-98 [-1, 64, 160, 160] 0
Conv2d-99 [-1, 32, 160, 160] 2,080
BatchNorm2d-100 [-1, 32, 160, 160] 64
ReLU-101 [-1, 32, 160, 160] 0
Conv2d-102 [-1, 32, 160, 160] 320
BatchNorm2d-103 [-1, 32, 160, 160] 64
ReLU-104 [-1, 32, 160, 160] 0
Conv2d-105 [-1, 64, 160, 160] 2,112
Inception-106 [-1, 64, 160, 160] 0
BatchNorm2d-107 [-1, 64, 160, 160] 128
ReLU-108 [-1, 64, 160, 160] 0
DoubleConv-109 [-1, 64, 160, 160] 0
Conv2d-110 [-1, 2, 160, 160] 130
================================================================
Total params: 379,298
Trainable params: 379,298
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.10
Forward/backward pass size (MB): 633.20
Params size (MB): 1.45
Estimated Total Size (MB): 634.75
----------------------------------------------------------------
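As a quick cross-check of the totals above, the parameter count can also be read directly off the model; a small sketch:

net = UNet(1, 2)
total = sum(p.numel() for p in net.parameters())
print(total)  # should match the 379,298 reported by torchsummary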
After training for 190 epochs, the test results are as follows: