目录
模型:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
class UNet3D(nn.Module):
"""
Baseline model for pulmonary airway segmentation
"""
def __init__(self, in_channels=1, out_channels=1, coord=True):
"""
:param in_channels: input channel numbers
:param out_channels: output channel numbers
:param coord: boolean, True=Use coordinates as position information, False=not
"""
super(UNet3D, self).__init__()
self._in_channels = in_channels
self._out_channels = out_channels
self._coord = coord
self.pooling = nn.MaxPool3d(kernel_size=(2, 2, 2))
self.upsampling = nn.Upsample(scale_factor=2)
self.conv1 = nn.Sequential(
nn.Conv3d(in_channels=self._in_channels, out_channels=8, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm3d(8),
nn.ReLU(inplace=True),
nn.Conv3d(8, 16, 3, 1, 1),
nn.InstanceNorm3d(16),
nn.ReLU(inplace=True))
self.conv2 = nn.Sequential(
nn.Conv3d(16, 16, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm3d(16),
nn.ReLU(inplace=True),
nn.Conv3d(16, 32, 3, 1, 1),
nn.InstanceNorm3d(32),
nn.ReLU(inplace=True))
self.conv3 = nn.Sequential(
nn.Conv3d(32, 32, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm3d(32),
nn.ReLU(inplace=True),
nn.Conv3d(32, 64, 3, 1, 1),
nn.InstanceNorm3d(64),
nn.ReLU(inplace=True))
self.conv4 = nn.Sequential(
nn.Conv3d(64, 64, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm3d(64),
nn.ReLU(inplace=True),
nn.Conv3d(64, 128, 3, 1, 1),
nn.InstanceNorm3d(128),
nn.ReLU(inplace=True))
self.conv5 = nn.Sequential(
nn.Conv3d(128, 128, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm3d(128),
nn.ReLU(inplace=True),
nn.Conv3d(128, 256, 3, 1, 1),
nn.InstanceNorm3d(256),
nn.ReLU(inplace=True))
self.conv6 = nn.Sequential(
nn.Conv3d(256 + 128, 128, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm3d(128),
nn.ReLU(inplace=True),
nn.Conv3d(128, 128, 3, 1, 1),
nn.InstanceNorm3d(128),
nn.ReLU(inplace=True))
self.conv7 = nn.Sequential(
nn.Conv3d(128 + 64, 64, 3, 1, 1),
nn.InstanceNorm3d(64),
nn.ReLU(inplace=True),
nn.Conv3d(64, 64, 3, 1, 1),
nn.InstanceNorm3d(64),
nn.ReLU(inplace=True))
self.conv8 = nn.Sequential(
nn.Conv3d(64 + 32, 32, 3, 1, 1),
nn.InstanceNorm3d(32),
nn.ReLU(inplace=True),
nn.Conv3d(32, 32, 3, 1, 1),
nn.InstanceNorm3d(32),
nn.ReLU(inplace=True))
if not self._coord:
num_channel_coord = 3
else:
num_channel_coord = 0
self.conv9 = nn.Sequential(
nn.Conv3d(32 + 16 + num_channel_coord, 16, 3, 1, 1),
nn.InstanceNorm3d(16),
nn.ReLU(inplace=True),
nn.Conv3d(16, 16, 3, 1, 1),
nn.InstanceNorm3d(16),
nn.ReLU(inplace=True))
self.sigmoid = nn.Sigmoid()
self.conv10 = nn.Conv3d(16, self._out_channels, 1, 1, 0)
def forward(self, input, coordmap=None):
"""
:param input: shape = (batch_size, num_channels, D, H, W) \
:param coordmap: shape = (batch_size, 3, D, H, W)
:return: output segmentation tensor, attention mapping
"""
conv1 = self.conv1(input)
x = self.pooling(conv1)
conv2 = self.conv2(x)
x = self.pooling(conv2)
conv3 = self.conv3(x)
x = self.pooling(conv3)
conv4 = self.conv4(x)
x = self.pooling(conv4)
conv5 = self.conv5(x)
x = self.upsampling(conv5)
x = torch.cat([x, conv4], dim=1)
conv6 = self.conv6(x)
x = self.upsampling(conv6)
x = torch.cat([x, conv3], dim=1)
conv7 = self.conv7(x)
x = self.upsampling(conv7)
x = torch.cat([x, conv2], dim=1)
conv8 = self.conv8(x)
x = self.upsampling(conv8)
if self._coord and (coordmap is not None):
x = torch.cat([x, conv1, coordmap], dim=1)
else:
x = torch.cat([x, conv1], dim=1)
conv9 = self.conv9(x)
x = self.conv10(conv9)
x = self.sigmoid(x)
return x
if __name__ == '__main__':
net = UNet3D(in_channels=1, out_channels=1)
print(net)
print('Number of network parameters:', sum(param.numel() for param in net.parameters()))
# Number of network parameters: 4118849 Baseline
计算模型参数的代码:
import torch
import torch.nn
from torchsummary import summary
from importlib import import_module
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = import_module('baseline')
config, net = model.get_model()
model = net.to(device)
summary(model, input_size=(1,64, 128, 128))
手算参数量结果:
UNet3D(
(pooling): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
(upsampling): Upsample(scale_factor=2.0, mode=nearest)
(conv1): Sequential(
######3*3*3+8 = 224
(0): Conv3d(1, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
###### 0
(1): InstanceNorm3d(8, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(2): ReLU(inplace=True)
######3*3*3*8*16 + 16 = 3472
(3): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(4): InstanceNorm3d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(5): ReLU(inplace=True)
)
(conv2): Sequential(
######3*3*3*16*16+16=6928
(0): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(1): InstanceNorm3d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
#0
(2): ReLU(inplace=True)
######3*3*3*16*32 + 32 = 13856
(3): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(4): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(5): ReLU(inplace=True)
)
(conv3): Sequential(
######3*3*3*32*32+32 = 27680
(0): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(2): ReLU(inplace=True)
######3*3*3*32*64+64 =55360
(3): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(4): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(5): ReLU(inplace=True)
)
(conv4): Sequential(
######3*3*3*64*64+64 =110656
(0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(2): ReLU(inplace=True)
######3*3*3*64*128+128 = 221312
(3): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(4): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(5): ReLU(inplace=True)
)
(conv5): Sequential(
######3*3*3*128*128+128 = 442496
(0): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(2): ReLU(inplace=True)
######3*3*3*128*256+256 = 884992
(3): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(4): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(5): ReLU(inplace=True)
)
(conv6): Sequential(
######3*3*3*384*128+128 = 1327232
(0): Conv3d(384, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(2): ReLU(inplace=True)
######3*3*3*128*128+128 = 442496
(3): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(4): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(5): ReLU(inplace=True)
)
(conv7): Sequential(
######3*3*3*192*64+64 = 331840
(0): Conv3d(192, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(2): ReLU(inplace=True)
#####3*3*3*64*64+64 = 110656
(3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(4): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(5): ReLU(inplace=True)
)
(conv8): Sequential(
#####3*3*3*96*32+32 = 82976
(0): Conv3d(96, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(2): ReLU(inplace=True)
#####3*3*3*32*32+32 = 27680
(3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(4): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(5): ReLU(inplace=True)
)
(conv9): Sequential(
#####3*3*3*48*16+16 = 20752
(0): Conv3d(48, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(1): InstanceNorm3d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(2): ReLU(inplace=True)
#####3*3*3*16*16+16 = 2928
(3): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
######0
(4): InstanceNorm3d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
######0
(5): ReLU(inplace=True)
)
######0
(sigmoid): Sigmoid()
#####1*1*1*16*1+1 = 17
(conv10): Conv3d(16, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1))
)
python代码计算结果:
输入大小:(1,64, 128, 128)
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv3d-1 [-1, 8, 64, 128, 128] 224
InstanceNorm3d-2 [-1, 8, 64, 128, 128] 0
ReLU-3 [-1, 8, 64, 128, 128] 0
Conv3d-4 [-1, 16, 64, 128, 128] 3,472
InstanceNorm3d-5 [-1, 16, 64, 128, 128] 0
ReLU-6 [-1, 16, 64, 128, 128] 0
MaxPool3d-7 [-1, 16, 32, 64, 64] 0
Conv3d-8 [-1, 16, 32, 64, 64] 6,928
InstanceNorm3d-9 [-1, 16, 32, 64, 64] 0
ReLU-10 [-1, 16, 32, 64, 64] 0
Conv3d-11 [-1, 32, 32, 64, 64] 13,856
InstanceNorm3d-12 [-1, 32, 32, 64, 64] 0
ReLU-13 [-1, 32, 32, 64, 64] 0
MaxPool3d-14 [-1, 32, 16, 32, 32] 0
Conv3d-15 [-1, 32, 16, 32, 32] 27,680
InstanceNorm3d-16 [-1, 32, 16, 32, 32] 0
ReLU-17 [-1, 32, 16, 32, 32] 0
Conv3d-18 [-1, 64, 16, 32, 32] 55,360
InstanceNorm3d-19 [-1, 64, 16, 32, 32] 0
ReLU-20 [-1, 64, 16, 32, 32] 0
MaxPool3d-21 [-1, 64, 8, 16, 16] 0
Conv3d-22 [-1, 64, 8, 16, 16] 110,656
InstanceNorm3d-23 [-1, 64, 8, 16, 16] 0
ReLU-24 [-1, 64, 8, 16, 16] 0
Conv3d-25 [-1, 128, 8, 16, 16] 221,312
InstanceNorm3d-26 [-1, 128, 8, 16, 16] 0
ReLU-27 [-1, 128, 8, 16, 16] 0
MaxPool3d-28 [-1, 128, 4, 8, 8] 0
Conv3d-29 [-1, 128, 4, 8, 8] 442,496
InstanceNorm3d-30 [-1, 128, 4, 8, 8] 0
ReLU-31 [-1, 128, 4, 8, 8] 0
Conv3d-32 [-1, 256, 4, 8, 8] 884,992
InstanceNorm3d-33 [-1, 256, 4, 8, 8] 0
ReLU-34 [-1, 256, 4, 8, 8] 0
Upsample-35 [-1, 256, 8, 16, 16] 0
Conv3d-36 [-1, 128, 8, 16, 16] 1,327,232
InstanceNorm3d-37 [-1, 128, 8, 16, 16] 0
ReLU-38 [-1, 128, 8, 16, 16] 0
Conv3d-39 [-1, 128, 8, 16, 16] 442,496
InstanceNorm3d-40 [-1, 128, 8, 16, 16] 0
ReLU-41 [-1, 128, 8, 16, 16] 0
Upsample-42 [-1, 128, 16, 32, 32] 0
Conv3d-43 [-1, 64, 16, 32, 32] 331,840
InstanceNorm3d-44 [-1, 64, 16, 32, 32] 0
ReLU-45 [-1, 64, 16, 32, 32] 0
Conv3d-46 [-1, 64, 16, 32, 32] 110,656
InstanceNorm3d-47 [-1, 64, 16, 32, 32] 0
ReLU-48 [-1, 64, 16, 32, 32] 0
Upsample-49 [-1, 64, 32, 64, 64] 0
Conv3d-50 [-1, 32, 32, 64, 64] 82,976
InstanceNorm3d-51 [-1, 32, 32, 64, 64] 0
ReLU-52 [-1, 32, 32, 64, 64] 0
Conv3d-53 [-1, 32, 32, 64, 64] 27,680
InstanceNorm3d-54 [-1, 32, 32, 64, 64] 0
ReLU-55 [-1, 32, 32, 64, 64] 0
Upsample-56 [-1, 32, 64, 128, 128] 0
Conv3d-57 [-1, 16, 64, 128, 128] 20,752
InstanceNorm3d-58 [-1, 16, 64, 128, 128] 0
ReLU-59 [-1, 16, 64, 128, 128] 0
Conv3d-60 [-1, 16, 64, 128, 128] 6,928
InstanceNorm3d-61 [-1, 16, 64, 128, 128] 0
ReLU-62 [-1, 16, 64, 128, 128] 0
Conv3d-63 [-1, 1, 64, 128, 128] 17
Sigmoid-64 [-1, 1, 64, 128, 128] 0
================================================================