import _init_paths
import torch
import torch. nn as nn
from layers import unetConv2, unetUp
from utils import init_weights, count_param
class UNet_Nested(nn.Module):
    """UNet++ (nested U-Net) for 2-D semantic segmentation.

    Builds a 5-level encoder (conv00..conv40) and the dense nested skip
    decoder X_ij of UNet++, with four 1x1 side outputs taken from the
    top row (X_01..X_04).

    Args:
        in_channels (int): number of channels in the input image.
        n_classes (int): number of output segmentation classes.
        feature_scale (int): divisor applied to the base channel widths
            [64, 128, 256, 512, 1024].
        is_deconv (bool): passed to ``unetUp`` — presumably selects
            transposed conv vs. interpolation upsampling (defined in
            ``layers``; confirm there).
        is_batchnorm (bool): passed to ``unetConv2`` to enable batch norm.
        is_ds (bool): deep supervision — when True, ``forward`` returns the
            average of the four side outputs; otherwise only the deepest
            output (from X_04).
    """

    def __init__(self, in_channels=3, n_classes=2, feature_scale=2,
                 is_deconv=True, is_batchnorm=True, is_ds=True):
        super(UNet_Nested, self).__init__()
        self.in_channels = in_channels
        self.feature_scale = feature_scale
        self.is_deconv = is_deconv
        self.is_batchnorm = is_batchnorm
        self.is_ds = is_ds

        # Base widths, shrunk by feature_scale. int(x / s) is kept (rather
        # than //) so a non-integer scale behaves exactly as before.
        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        self.maxpool = nn.MaxPool2d(kernel_size=2)

        # Encoder column (j = 0).
        self.conv00 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.conv10 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.conv20 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.conv30 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.conv40 = unetConv2(filters[3], filters[4], self.is_batchnorm)

        # Nested decoder. up_concat<i><j> produces X_ij from X_(i+1)(j-1)
        # plus all earlier nodes on row i; the trailing int is the number
        # of tensors concatenated (default 2 for the first column).
        self.up_concat01 = unetUp(filters[1], filters[0], self.is_deconv)
        self.up_concat11 = unetUp(filters[2], filters[1], self.is_deconv)
        self.up_concat21 = unetUp(filters[3], filters[2], self.is_deconv)
        self.up_concat31 = unetUp(filters[4], filters[3], self.is_deconv)

        self.up_concat02 = unetUp(filters[1], filters[0], self.is_deconv, 3)
        self.up_concat12 = unetUp(filters[2], filters[1], self.is_deconv, 3)
        self.up_concat22 = unetUp(filters[3], filters[2], self.is_deconv, 3)

        self.up_concat03 = unetUp(filters[1], filters[0], self.is_deconv, 4)
        self.up_concat13 = unetUp(filters[2], filters[1], self.is_deconv, 4)

        self.up_concat04 = unetUp(filters[1], filters[0], self.is_deconv, 5)

        # 1x1 side-output heads for deep supervision.
        self.final_1 = nn.Conv2d(filters[0], n_classes, 1)
        self.final_2 = nn.Conv2d(filters[0], n_classes, 1)
        self.final_3 = nn.Conv2d(filters[0], n_classes, 1)
        self.final_4 = nn.Conv2d(filters[0], n_classes, 1)

        # Both layer types receive the identical init call, so the two
        # original branches are collapsed into one isinstance tuple test.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        """Run the nested U-Net.

        Args:
            inputs: image batch; assumed (N, in_channels, H, W) with H, W
                divisible by 16 (four 2x poolings) — TODO confirm against
                ``unetUp``.

        Returns:
            (N, n_classes, H, W) logits: the mean of the four side outputs
            when ``is_ds`` is True, else the deepest side output alone.
        """
        # Encoder column.
        X_00 = self.conv00(inputs)
        maxpool0 = self.maxpool(X_00)
        X_10 = self.conv10(maxpool0)
        maxpool1 = self.maxpool(X_10)
        X_20 = self.conv20(maxpool1)
        maxpool2 = self.maxpool(X_20)
        X_30 = self.conv30(maxpool2)
        maxpool3 = self.maxpool(X_30)
        X_40 = self.conv40(maxpool3)

        # Nested skip columns j = 1..4.
        X_01 = self.up_concat01(X_10, X_00)
        X_11 = self.up_concat11(X_20, X_10)
        X_21 = self.up_concat21(X_30, X_20)
        X_31 = self.up_concat31(X_40, X_30)

        X_02 = self.up_concat02(X_11, X_00, X_01)
        X_12 = self.up_concat12(X_21, X_10, X_11)
        X_22 = self.up_concat22(X_31, X_20, X_21)

        X_03 = self.up_concat03(X_12, X_00, X_01, X_02)
        X_13 = self.up_concat13(X_22, X_10, X_11, X_12)

        X_04 = self.up_concat04(X_13, X_00, X_01, X_02, X_03)

        # Side outputs and deep supervision.
        final_1 = self.final_1(X_01)
        final_2 = self.final_2(X_02)
        final_3 = self.final_3(X_03)
        final_4 = self.final_4(X_04)

        final = (final_1 + final_2 + final_3 + final_4) / 4

        if self.is_ds:
            return final
        else:
            return final_4
def _smoke_test():
    """Build the model, run one forward pass, report size, export a trace."""
    print('#### Test Case ###')
    model = UNet_Nested()
    # eval() before inference/tracing so BatchNorm uses running statistics
    # and the smoke forward pass does not mutate them.
    model.eval()
    param = count_param(model)
    # torch.autograd.Variable is deprecated since PyTorch 0.4 — plain
    # tensors carry autograd state; no_grad() since we only inspect shapes.
    x = torch.rand(1, 3, 512, 512)
    with torch.no_grad():
        y = model(x)
    print('Output shape:', y.shape)
    print('UNet++ total parameters: %.2fM (%d)' % (param / 1e6, param))
    trace = torch.jit.trace(model, torch.randn(1, 3, 512, 512))
    torch.jit.save(trace, 'UNet++_model.pt')


if __name__ == '__main__':
    # Guarded so importing this module no longer runs the smoke test.
    _smoke_test()