Calculating Neural Network Parameter Counts: UNet3D as an Example

Table of Contents

Model:

Code to compute the model parameters:

Hand-calculated parameter counts:

Results from the Python code:


Model:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os


class UNet3D(nn.Module):
	"""
	Baseline model for pulmonary airway segmentation
	"""
	def __init__(self, in_channels=1, out_channels=1, coord=True):
		"""
		:param in_channels: input channel numbers
		:param out_channels: output channel numbers
		:param coord: boolean, True=Use coordinates as position information, False=not
		"""
		super(UNet3D, self).__init__()
		self._in_channels = in_channels
		self._out_channels = out_channels
		self._coord = coord
		self.pooling = nn.MaxPool3d(kernel_size=(2, 2, 2))
		self.upsampling = nn.Upsample(scale_factor=2)
		self.conv1 = nn.Sequential(
			nn.Conv3d(in_channels=self._in_channels, out_channels=8, kernel_size=3, stride=1, padding=1),
			nn.InstanceNorm3d(8),
			nn.ReLU(inplace=True),
			nn.Conv3d(8, 16, 3, 1, 1),
			nn.InstanceNorm3d(16),
			nn.ReLU(inplace=True))
		
		self.conv2 = nn.Sequential(
			nn.Conv3d(16, 16, kernel_size=3, stride=1, padding=1),
			nn.InstanceNorm3d(16),
			nn.ReLU(inplace=True),
			nn.Conv3d(16, 32, 3, 1, 1),
			nn.InstanceNorm3d(32),
			nn.ReLU(inplace=True))

		self.conv3 = nn.Sequential(
			nn.Conv3d(32, 32, kernel_size=3, stride=1, padding=1),
			nn.InstanceNorm3d(32),
			nn.ReLU(inplace=True),
			nn.Conv3d(32, 64, 3, 1, 1),
			nn.InstanceNorm3d(64),
			nn.ReLU(inplace=True))
	
		self.conv4 = nn.Sequential(
			nn.Conv3d(64, 64, kernel_size=3, stride=1, padding=1),
			nn.InstanceNorm3d(64),
			nn.ReLU(inplace=True),
			nn.Conv3d(64, 128, 3, 1, 1),
			nn.InstanceNorm3d(128),
			nn.ReLU(inplace=True))

		self.conv5 = nn.Sequential(
			nn.Conv3d(128, 128, kernel_size=3, stride=1, padding=1),
			nn.InstanceNorm3d(128),
			nn.ReLU(inplace=True),
			nn.Conv3d(128, 256, 3, 1, 1),
			nn.InstanceNorm3d(256),
			nn.ReLU(inplace=True))

		self.conv6 = nn.Sequential(
			nn.Conv3d(256 + 128, 128, kernel_size=3, stride=1, padding=1),
			nn.InstanceNorm3d(128),
			nn.ReLU(inplace=True),
			nn.Conv3d(128, 128, 3, 1, 1),
			nn.InstanceNorm3d(128),
			nn.ReLU(inplace=True))
		
		self.conv7 = nn.Sequential(
			nn.Conv3d(128 + 64, 64, 3, 1, 1),
			nn.InstanceNorm3d(64),
			nn.ReLU(inplace=True),
			nn.Conv3d(64, 64, 3, 1, 1),
			nn.InstanceNorm3d(64),
			nn.ReLU(inplace=True))
		
		self.conv8 = nn.Sequential(
			nn.Conv3d(64 + 32, 32, 3, 1, 1),
			nn.InstanceNorm3d(32),
			nn.ReLU(inplace=True),
			nn.Conv3d(32, 32, 3, 1, 1),
			nn.InstanceNorm3d(32),
			nn.ReLU(inplace=True))
		
		if self._coord:
			# 3 coordinate channels are concatenated before conv9 when coord=True
			num_channel_coord = 3
		else:
			num_channel_coord = 0
		self.conv9 = nn.Sequential(
			nn.Conv3d(32 + 16 + num_channel_coord, 16, 3, 1, 1),
			nn.InstanceNorm3d(16),
			nn.ReLU(inplace=True),
			nn.Conv3d(16, 16, 3, 1, 1),
			nn.InstanceNorm3d(16),
			nn.ReLU(inplace=True))
	
		self.sigmoid = nn.Sigmoid()
		self.conv10 = nn.Conv3d(16, self._out_channels, 1, 1, 0)

	def forward(self, input, coordmap=None):
		"""
		:param input: shape = (batch_size, num_channels, D, H, W)
		:param coordmap: shape = (batch_size, 3, D, H, W)
		:return: output segmentation tensor, attention mapping
		"""
		conv1 = self.conv1(input)
		x = self.pooling(conv1)
		
		conv2 = self.conv2(x)
		x = self.pooling(conv2)
		
		conv3 = self.conv3(x)
		x = self.pooling(conv3)
		
		conv4 = self.conv4(x)
		x = self.pooling(conv4)

		conv5 = self.conv5(x)

		x = self.upsampling(conv5)
		x = torch.cat([x, conv4], dim=1)
		conv6 = self.conv6(x)
		
		x = self.upsampling(conv6)
		x = torch.cat([x, conv3], dim=1)
		conv7 = self.conv7(x)
		
		x = self.upsampling(conv7)
		x = torch.cat([x, conv2], dim=1)
		conv8 = self.conv8(x)
		
		x = self.upsampling(conv8)

		if self._coord and (coordmap is not None):
			x = torch.cat([x, conv1, coordmap], dim=1)
		else:
			x = torch.cat([x, conv1], dim=1)

		conv9 = self.conv9(x)
		
		x = self.conv10(conv9)

		x = self.sigmoid(x)	
		

		return x

if __name__ == '__main__':
	net = UNet3D(in_channels=1, out_channels=1)
	print(net)
	print('Number of network parameters:', sum(param.numel() for param in net.parameters()))
# Number of network parameters: 4118849 Baseline
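The print statement above counts every parameter in the network; a common variant counts only the trainable ones by filtering on requires_grad. In this model all parameters are trainable, so the two figures agree. A minimal sketch, assuming the model definition above is saved as baseline.py:

from baseline import UNet3D  # assumption: the model definition above lives in baseline.py

net = UNet3D(in_channels=1, out_channels=1)
total = sum(p.numel() for p in net.parameters())
trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('Total parameters:', total)
print('Trainable parameters:', trainable)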

Code to compute the model parameters:

import torch
from torchsummary import summary
from importlib import import_module

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 'baseline' is the file containing the UNet3D definition above together with a
# get_model() helper that returns (config, net)
model = import_module('baseline')
config, net = model.get_model()
model = net.to(device)
summary(model, input_size=(1, 64, 128, 128))
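If torchsummary is not available, the same cross-check can be done with plain PyTorch by summing the parameter count of each top-level block. This is a minimal sketch assuming the model file above is saved as baseline.py and exposes the UNet3D class:

from baseline import UNet3D  # assumption: the model definition above lives in baseline.py

net = UNet3D(in_channels=1, out_channels=1)
for name, module in net.named_children():
	# parameter count per top-level block (pooling/upsampling/sigmoid contribute 0)
	print(name, sum(p.numel() for p in module.parameters()))
print('total', sum(p.numel() for p in net.parameters()))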

Hand-calculated parameter counts:

The structure below is the model as printed by print(net); each ###### annotation gives the parameter count of the layer that follows it. A 3D convolution contributes kernel_D * kernel_H * kernel_W * C_in * C_out weights plus C_out bias terms, while InstanceNorm3d (with affine=False) and ReLU contribute no parameters. Note that this printout corresponds to the model without the 3 coordinate channels, so the first convolution of conv9 takes 48 = 32 + 16 input channels.
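As a quick sanity check on these numbers, a small helper reproduces each annotation; conv3d_params is my own name for illustration, not part of the original code.

def conv3d_params(c_in, c_out, k=3, bias=True):
	# weights: k*k*k*c_in*c_out, plus c_out bias terms if bias=True
	return k ** 3 * c_in * c_out + (c_out if bias else 0)

print(conv3d_params(1, 8))        # 224   (conv1, first Conv3d)
print(conv3d_params(8, 16))       # 3472  (conv1, second Conv3d)
print(conv3d_params(48, 16))      # 20752 (conv9, first Conv3d, without coordinate channels)
print(conv3d_params(16, 1, k=1))  # 17    (conv10, 1x1x1 convolution)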

UNet3D(
  (pooling): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
  (upsampling): Upsample(scale_factor=2.0, mode=nearest)
  (conv1): Sequential(
    ###### 3*3*3*1*8 + 8 = 224
    (0): Conv3d(1, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (1): InstanceNorm3d(8, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (2): ReLU(inplace=True)
    ###### 3*3*3*8*16 + 16 = 3472
    (3): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (4): InstanceNorm3d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (5): ReLU(inplace=True)
  )
  (conv2): Sequential(
    ###### 3*3*3*16*16 + 16 = 6928
    (0): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (1): InstanceNorm3d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (2): ReLU(inplace=True)
    ###### 3*3*3*16*32 + 32 = 13856
    (3): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (4): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (5): ReLU(inplace=True)
  )
  (conv3): Sequential(
    ###### 3*3*3*32*32 + 32 = 27680
    (0): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (2): ReLU(inplace=True)
    ###### 3*3*3*32*64 + 64 = 55360
    (3): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (4): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (5): ReLU(inplace=True)
  )
  (conv4): Sequential(
    ###### 3*3*3*64*64 + 64 = 110656
    (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (2): ReLU(inplace=True)
    ###### 3*3*3*64*128 + 128 = 221312
    (3): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (4): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (5): ReLU(inplace=True)
  )
  (conv5): Sequential(
    ###### 3*3*3*128*128 + 128 = 442496
    (0): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (2): ReLU(inplace=True)
    ###### 3*3*3*128*256 + 256 = 884992
    (3): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (4): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (5): ReLU(inplace=True)
  )
  (conv6): Sequential(
    ###### 3*3*3*384*128 + 128 = 1327232
    (0): Conv3d(384, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (2): ReLU(inplace=True)
    ###### 3*3*3*128*128 + 128 = 442496
    (3): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (4): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (5): ReLU(inplace=True)
  )
  (conv7): Sequential(
    ###### 3*3*3*192*64 + 64 = 331840
    (0): Conv3d(192, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (2): ReLU(inplace=True)
    ###### 3*3*3*64*64 + 64 = 110656
    (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (4): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (5): ReLU(inplace=True)
  )
  (conv8): Sequential(
    ###### 3*3*3*96*32 + 32 = 82976
    (0): Conv3d(96, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (2): ReLU(inplace=True)
    ###### 3*3*3*32*32 + 32 = 27680
    (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (4): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (5): ReLU(inplace=True)
  )
  (conv9): Sequential(
    ###### 3*3*3*48*16 + 16 = 20752
    (0): Conv3d(48, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (1): InstanceNorm3d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (2): ReLU(inplace=True)
    ###### 3*3*3*16*16 + 16 = 6928
    (3): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    ###### 0
    (4): InstanceNorm3d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    ###### 0
    (5): ReLU(inplace=True)
  )
  ###### 0
  (sigmoid): Sigmoid()
  ###### 1*1*1*16*1 + 1 = 17
  (conv10): Conv3d(16, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1))
)

Results from the Python code:

Input size: (1, 64, 128, 128)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1      [-1, 8, 64, 128, 128]             224
    InstanceNorm3d-2      [-1, 8, 64, 128, 128]               0
              ReLU-3      [-1, 8, 64, 128, 128]               0
            Conv3d-4     [-1, 16, 64, 128, 128]           3,472
    InstanceNorm3d-5     [-1, 16, 64, 128, 128]               0
              ReLU-6     [-1, 16, 64, 128, 128]               0
         MaxPool3d-7       [-1, 16, 32, 64, 64]               0
            Conv3d-8       [-1, 16, 32, 64, 64]           6,928
    InstanceNorm3d-9       [-1, 16, 32, 64, 64]               0
             ReLU-10       [-1, 16, 32, 64, 64]               0
           Conv3d-11       [-1, 32, 32, 64, 64]          13,856
   InstanceNorm3d-12       [-1, 32, 32, 64, 64]               0
             ReLU-13       [-1, 32, 32, 64, 64]               0
        MaxPool3d-14       [-1, 32, 16, 32, 32]               0
           Conv3d-15       [-1, 32, 16, 32, 32]          27,680
   InstanceNorm3d-16       [-1, 32, 16, 32, 32]               0
             ReLU-17       [-1, 32, 16, 32, 32]               0
           Conv3d-18       [-1, 64, 16, 32, 32]          55,360
   InstanceNorm3d-19       [-1, 64, 16, 32, 32]               0
             ReLU-20       [-1, 64, 16, 32, 32]               0
        MaxPool3d-21        [-1, 64, 8, 16, 16]               0
           Conv3d-22        [-1, 64, 8, 16, 16]         110,656
   InstanceNorm3d-23        [-1, 64, 8, 16, 16]               0
             ReLU-24        [-1, 64, 8, 16, 16]               0
           Conv3d-25       [-1, 128, 8, 16, 16]         221,312
   InstanceNorm3d-26       [-1, 128, 8, 16, 16]               0
             ReLU-27       [-1, 128, 8, 16, 16]               0
        MaxPool3d-28         [-1, 128, 4, 8, 8]               0
           Conv3d-29         [-1, 128, 4, 8, 8]         442,496
   InstanceNorm3d-30         [-1, 128, 4, 8, 8]               0
             ReLU-31         [-1, 128, 4, 8, 8]               0
           Conv3d-32         [-1, 256, 4, 8, 8]         884,992
   InstanceNorm3d-33         [-1, 256, 4, 8, 8]               0
             ReLU-34         [-1, 256, 4, 8, 8]               0
         Upsample-35       [-1, 256, 8, 16, 16]               0
           Conv3d-36       [-1, 128, 8, 16, 16]       1,327,232
   InstanceNorm3d-37       [-1, 128, 8, 16, 16]               0
             ReLU-38       [-1, 128, 8, 16, 16]               0
           Conv3d-39       [-1, 128, 8, 16, 16]         442,496
   InstanceNorm3d-40       [-1, 128, 8, 16, 16]               0
             ReLU-41       [-1, 128, 8, 16, 16]               0
         Upsample-42      [-1, 128, 16, 32, 32]               0
           Conv3d-43       [-1, 64, 16, 32, 32]         331,840
   InstanceNorm3d-44       [-1, 64, 16, 32, 32]               0
             ReLU-45       [-1, 64, 16, 32, 32]               0
           Conv3d-46       [-1, 64, 16, 32, 32]         110,656
   InstanceNorm3d-47       [-1, 64, 16, 32, 32]               0
             ReLU-48       [-1, 64, 16, 32, 32]               0
         Upsample-49       [-1, 64, 32, 64, 64]               0
           Conv3d-50       [-1, 32, 32, 64, 64]          82,976
   InstanceNorm3d-51       [-1, 32, 32, 64, 64]               0
             ReLU-52       [-1, 32, 32, 64, 64]               0
           Conv3d-53       [-1, 32, 32, 64, 64]          27,680
   InstanceNorm3d-54       [-1, 32, 32, 64, 64]               0
             ReLU-55       [-1, 32, 32, 64, 64]               0
         Upsample-56     [-1, 32, 64, 128, 128]               0
           Conv3d-57     [-1, 16, 64, 128, 128]          20,752
   InstanceNorm3d-58     [-1, 16, 64, 128, 128]               0
             ReLU-59     [-1, 16, 64, 128, 128]               0
           Conv3d-60     [-1, 16, 64, 128, 128]           6,928
   InstanceNorm3d-61     [-1, 16, 64, 128, 128]               0
             ReLU-62     [-1, 16, 64, 128, 128]               0
           Conv3d-63      [-1, 1, 64, 128, 128]              17
          Sigmoid-64      [-1, 1, 64, 128, 128]               0
================================================================
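The summary output is truncated here; summing the Param # column gives 4,117,553 parameters for this configuration (conv9 without coordinate channels). With coord=True the first convolution of conv9 takes 51 = 32 + 16 + 3 input channels, which adds 3*3*3*3*16 = 1,296 weights and brings the total to 4,118,849, matching the comment in the model file. A minimal check, assuming the model definition above is saved as baseline.py:

from baseline import UNet3D  # assumption: the model definition above lives in baseline.py

# without the 3 coordinate channels at conv9
print(sum(p.numel() for p in UNet3D(coord=False).parameters()))  # 4117553
# with the 3 coordinate channels at conv9
print(sum(p.numel() for p in UNet3D(coord=True).parameters()))   # 4118849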
