目录
模型:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
class UNet3D(nn.Module):
"""
Baseline model for pulmonary airway segmentation
"""
def __init__(self, in_channels=1, out_channels=1, coord=True):
"""
:param in_channels: input channel numbers
:param out_channels: output channel numbers
:param coord: boolean, True=Use coordinates as position information, False=not
"""
super(UNet3D, self).__init__()
self._in_channels = in_channels
self._out_channels = out_channels
self._coord = coord
self.pooling = nn.MaxPool3d(kernel_size=(2, 2, 2))
self.upsampling = nn.Upsample(scale_factor=2)
self.conv1 = nn.Sequential(
nn.Conv3d(in_channels=self._in_channels, out_channels=8, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm3d(8),
nn.ReLU(inplace=True),
nn.Conv3d(8, 16, 3, 1, 1),
nn.InstanceNorm3d(16),
nn.ReLU(inplace=True))
self.conv2 = nn.Sequential(
nn.Conv3d(16, 16, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm3d(16),
nn.ReLU(inplace=True),
nn.Conv3d(16, 32, 3, 1, 1),
nn.InstanceNorm3d(32),
nn.ReLU(inplace=True))
self.conv3 = nn.Sequential(
nn.Conv3d(32, 32, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm3d(32),
nn.ReLU(inplace=True),
nn.Conv3d(32, 64, 3, 1, 1),
nn.InstanceNorm3d(64),
nn.ReLU(inplace=True))
self.conv4 = nn.Sequential(
nn.Conv3d(64, 64, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm3d(64),
nn.ReLU(inplace=True),
nn.Conv3d(64, 128, 3, 1, 1),
nn.InstanceNorm3d(128),
nn.ReLU(inplace=True))
self.conv5 = nn.Sequential(
nn.Conv3d(128, 128, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm3d(128),
nn.ReLU(inplace=True),
nn.Conv3d(128, 256, 3, 1, 1),
nn.InstanceNorm3d(256),
nn.ReLU(inplace=True))
self.conv6 = nn.Sequential(
nn.Conv3d(256 + 128, 128, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm3d(128),
nn.ReLU(inplace=True),
nn.Conv3d(128, 128, 3, 1, 1),
nn.InstanceNorm3d(128),
nn.ReLU(inplace=True))
self.conv7 = nn.Sequential(
nn.Conv3d(128 + 64, 64, 3, 1, 1),
nn.InstanceNorm3d(64),
nn.ReLU(inplace=True),
nn.Conv3d(64, 64, 3, 1, 1),
nn.InstanceNorm3d(64),
nn.ReLU(inplace=True))
self.conv8 = nn.Sequential(
nn.Conv3d(64 + 32, 32, 3, 1, 1),
nn.InstanceNorm3d(32),
nn.ReLU(inplace=True),
nn.Conv3d(32, 32, 3, 1, 1),
nn.InstanceNorm3d(32),
nn.ReLU(inplace=True))
if not self._coord:
num_channel_coord = 3
else:
num_channel_coord = 0
self.conv9 = nn.Sequential(
nn.Conv3d(32 + 16 + num_channel_coord, 16, 3, 1, 1),
nn.InstanceNorm3d(16),
nn.ReLU(inplace=True),
nn.Conv3d(16, 16, 3, 1, 1),
nn.InstanceNorm3d(16),
nn.ReLU(inplace=True))
self.sigmoid = nn.Sigmoid()
self.conv10 = nn.Conv3d(16, self._out_channels, 1, 1, 0)
def forward(self, input, coordmap=None):
"""
:param input: shape = (batch_size, num_channels, D, H, W) \
:param coordmap: shape = (batch_size, 3, D, H, W)
:return: output segmentation tensor, attention mapping
"""
conv1 = self.conv1(input)
x = self.pooling(conv1)
con