TensorLayer and TensorFlow implementation of U-Net image segmentation.
The U-Net is a convolutional network architecture for fast and precise segmentation of images. Up to now it has outperformed the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. It won the Grand Challenge for Computer-Automated Detection of Caries in Bitewing Radiography at ISBI 2015, and it won the Cell Tracking Challenge at ISBI 2015 on the two most challenging transmitted-light microscopy categories (phase contrast and DIC microscopy) by a large margin (see also our announcement).
More details can be found in https://github.com/zsdonghao?tab=repositories
With Simplified Convolutional Layer APIs (TensorLayer 1.2.5)
def u_net_2d_64_1024_deconv(x, n_out=2):
    """2-D U-Net for image segmentation (simplified TensorLayer Conv APIs).

    Contracting path: 64 -> 128 -> 256 -> 512 -> 1024 filters with 2x2
    max-pooling; expanding path mirrors it with stride-2 transposed
    convolutions, each concatenated (skip connection) with the matching
    encoder feature map before two 3x3 convolutions.

    Parameters
    -----------
    x : tensor or placeholder of input with shape [batch_size, row, col, channel]
    n_out : int
        Number of output channels, default 2 (foreground / background).

    Returns
    --------
    conv1 : TensorLayer layer whose output is the per-pixel logits
    outputs : tensor, the pixel-wise softmax of the logits
    """
    # NOTE(review): the original had `from tensorlayer.layers import *` fused onto
    # the `def` line; a star import inside a function body is a SyntaxError in
    # Python 3, so the needed names are imported explicitly instead.
    from tensorlayer.layers import (InputLayer, Conv2d, MaxPool2d, DeConv2d,
                                    ConcatLayer)
    # Use the public shape accessor rather than the private `x._shape` attribute.
    nx = int(x.get_shape()[1])
    ny = int(x.get_shape()[2])
    nz = int(x.get_shape()[3])
    print(" * Input: size of image: %d %d %d" % (nx, ny, nz))
    w_init = tf.truncated_normal_initializer(stddev=0.01)
    b_init = tf.constant_initializer(value=0.0)
    inputs = InputLayer(x, name='inputs')
    # ---- contracting path -------------------------------------------------
    conv1 = Conv2d(inputs, 64, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='conv1_1')
    conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='conv1_2')
    pool1 = MaxPool2d(conv1, (2, 2), padding='SAME', name='pool1')
    conv2 = Conv2d(pool1, 128, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='conv2_1')
    conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='conv2_2')
    pool2 = MaxPool2d(conv2, (2, 2), padding='SAME', name='pool2')
    conv3 = Conv2d(pool2, 256, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='conv3_1')
    conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='conv3_2')
    pool3 = MaxPool2d(conv3, (2, 2), padding='SAME', name='pool3')
    conv4 = Conv2d(pool3, 512, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='conv4_1')
    conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='conv4_2')
    pool4 = MaxPool2d(conv4, (2, 2), padding='SAME', name='pool4')
    conv5 = Conv2d(pool4, 1024, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='conv5_1')
    conv5 = Conv2d(conv5, 1024, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='conv5_2')
    print(" * After conv: %s" % conv5.outputs)
    # ---- expanding path ---------------------------------------------------
    # out_size must be integers: use floor division (`/` yields floats on Py3).
    up4 = DeConv2d(conv5, 512, (3, 3), out_size=(nx // 8, ny // 8),
                   strides=(2, 2), padding='SAME', act=None,
                   W_init=w_init, b_init=b_init, name='deconv4')
    up4 = ConcatLayer([up4, conv4], concat_dim=3, name='concat4')
    conv4 = Conv2d(up4, 512, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='uconv4_1')
    conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='uconv4_2')
    up3 = DeConv2d(conv4, 256, (3, 3), out_size=(nx // 4, ny // 4),
                   strides=(2, 2), padding='SAME', act=None,
                   W_init=w_init, b_init=b_init, name='deconv3')
    up3 = ConcatLayer([up3, conv3], concat_dim=3, name='concat3')
    conv3 = Conv2d(up3, 256, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='uconv3_1')
    conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='uconv3_2')
    up2 = DeConv2d(conv3, 128, (3, 3), out_size=(nx // 2, ny // 2),
                   strides=(2, 2), padding='SAME', act=None,
                   W_init=w_init, b_init=b_init, name='deconv2')
    up2 = ConcatLayer([up2, conv2], concat_dim=3, name='concat2')
    conv2 = Conv2d(up2, 128, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='uconv2_1')
    conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='uconv2_2')
    up1 = DeConv2d(conv2, 64, (3, 3), out_size=(nx, ny),
                   strides=(2, 2), padding='SAME', act=None,
                   W_init=w_init, b_init=b_init, name='deconv1')
    up1 = ConcatLayer([up1, conv1], concat_dim=3, name='concat1')
    conv1 = Conv2d(up1, 64, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='uconv1_1')
    conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, padding='SAME',
                   W_init=w_init, b_init=b_init, name='uconv1_2')
    # Final 1x1 convolution maps 64 channels to n_out class logits.
    conv1 = Conv2d(conv1, n_out, (1, 1), act=None, name='uconv1')
    print(" * Output: %s" % conv1.outputs)
    outputs = tl.act.pixel_wise_softmax(conv1.outputs)
    return conv1, outputs
With Professional Convolutional Layer APIs.
def u_net_2d_64_1024_deconv_pro(x, n_out=2):
    """2-D U-Net for Image Segmentation (professional Conv2dLayer APIs).

    Parameters
    -----------
    x : tensor or placeholder of input with shape [batch_size, row, col, channel]
    n_out : int
        Number of output channels, default 2 for foreground and background
        (binary segmentation).

    Returns
    --------
    network : TensorLayer layer class with identity (logit) output
    outputs : tensor, the output with pixel-wise softmax

    Notes
    -----
    - Recommend to use Adam with learning rate of 1e-5.
    - Layer names such as 'devcon1_1' are kept verbatim (typos included) so
      variable scopes stay compatible with existing checkpoints.
    """
    # Use the public shape accessor rather than the private `x._shape` attribute.
    # NOTE(review): assumes the batch dimension is statically known -- a
    # placeholder with shape[0] == None would fail here; confirm with callers.
    batch_size = int(x.get_shape()[0])
    nx = int(x.get_shape()[1])
    ny = int(x.get_shape()[2])
    nz = int(x.get_shape()[3])
    print(" * Input: size of image: %d %d %d" % (nx, ny, nz))
    ## define initializer
    w_init = tf.truncated_normal_initializer(stddev=0.01)
    b_init = tf.constant_initializer(value=0.0)
    ## u-net model
    # convolution (contracting path)
    net_in = tl.layers.InputLayer(x, name='input')
    conv1 = tl.layers.Conv2dLayer(net_in, act=tf.nn.relu,
                shape=[3, 3, nz, 64], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv1')
    conv2 = tl.layers.Conv2dLayer(conv1, act=tf.nn.relu,
                shape=[3, 3, 64, 64], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv2')
    pool1 = tl.layers.PoolLayer(conv2, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='SAME',
                pool=tf.nn.max_pool, name='pool1')
    conv3 = tl.layers.Conv2dLayer(pool1, act=tf.nn.relu,
                shape=[3, 3, 64, 128], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv3')
    conv4 = tl.layers.Conv2dLayer(conv3, act=tf.nn.relu,
                shape=[3, 3, 128, 128], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv4')
    pool2 = tl.layers.PoolLayer(conv4, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='SAME',
                pool=tf.nn.max_pool, name='pool2')
    conv5 = tl.layers.Conv2dLayer(pool2, act=tf.nn.relu,
                shape=[3, 3, 128, 256], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv5')
    conv6 = tl.layers.Conv2dLayer(conv5, act=tf.nn.relu,
                shape=[3, 3, 256, 256], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv6')
    pool3 = tl.layers.PoolLayer(conv6, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='SAME',
                pool=tf.nn.max_pool, name='pool3')
    conv7 = tl.layers.Conv2dLayer(pool3, act=tf.nn.relu,
                shape=[3, 3, 256, 512], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv7')
    conv8 = tl.layers.Conv2dLayer(conv7, act=tf.nn.relu,
                shape=[3, 3, 512, 512], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv8')
    pool4 = tl.layers.PoolLayer(conv8, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='SAME',
                pool=tf.nn.max_pool, name='pool4')
    conv9 = tl.layers.Conv2dLayer(pool4, act=tf.nn.relu,
                shape=[3, 3, 512, 1024], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv9')
    conv10 = tl.layers.Conv2dLayer(conv9, act=tf.nn.relu,
                shape=[3, 3, 1024, 1024], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv10')
    print(" * After conv: %s" % conv10.outputs)  # (batch_size, 32, 32, 1024)
    # deconvolution (expanding path); output_shape requires integers, so use
    # floor division -- plain `/` yields floats under Python 3.
    deconv1 = tl.layers.DeConv2dLayer(conv10, act=tf.identity,
                shape=[3, 3, 512, 1024], strides=[1, 2, 2, 1],
                output_shape=[batch_size, nx // 8, ny // 8, 512],
                padding='SAME', W_init=w_init, b_init=b_init, name='devcon1_1')
    deconv1_2 = tl.layers.ConcatLayer([conv8, deconv1], concat_dim=3, name='concat1_2')
    deconv1_3 = tl.layers.Conv2dLayer(deconv1_2, act=tf.nn.relu,
                shape=[3, 3, 1024, 512], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv1_3')
    deconv1_4 = tl.layers.Conv2dLayer(deconv1_3, act=tf.nn.relu,
                shape=[3, 3, 512, 512], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv1_4')
    deconv2 = tl.layers.DeConv2dLayer(deconv1_4, act=tf.identity,
                shape=[3, 3, 256, 512], strides=[1, 2, 2, 1],
                output_shape=[batch_size, nx // 4, ny // 4, 256],
                padding='SAME', W_init=w_init, b_init=b_init, name='devcon2_1')
    deconv2_2 = tl.layers.ConcatLayer([conv6, deconv2], concat_dim=3, name='concat2_2')
    deconv2_3 = tl.layers.Conv2dLayer(deconv2_2, act=tf.nn.relu,
                shape=[3, 3, 512, 256], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv2_3')
    deconv2_4 = tl.layers.Conv2dLayer(deconv2_3, act=tf.nn.relu,
                shape=[3, 3, 256, 256], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv2_4')
    deconv3 = tl.layers.DeConv2dLayer(deconv2_4, act=tf.identity,
                shape=[3, 3, 128, 256], strides=[1, 2, 2, 1],
                output_shape=[batch_size, nx // 2, ny // 2, 128],
                padding='SAME', W_init=w_init, b_init=b_init, name='devcon3_1')
    deconv3_2 = tl.layers.ConcatLayer([conv4, deconv3], concat_dim=3, name='concat3_2')
    # FIX(review): the original used act=tf.identity here (with the relu
    # commented out) while every sibling decoder conv uses relu -- restored
    # relu for consistency with conv1_3/2_3/4_3.
    deconv3_3 = tl.layers.Conv2dLayer(deconv3_2, act=tf.nn.relu,
                shape=[3, 3, 256, 128], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv3_3')
    deconv3_4 = tl.layers.Conv2dLayer(deconv3_3, act=tf.nn.relu,
                shape=[3, 3, 128, 128], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv3_4')
    deconv4 = tl.layers.DeConv2dLayer(deconv3_4, act=tf.identity,
                shape=[3, 3, 64, 128], strides=[1, 2, 2, 1],
                output_shape=[batch_size, nx, ny, 64],
                padding='SAME', W_init=w_init, b_init=b_init, name='devconv4_1')
    deconv4_2 = tl.layers.ConcatLayer([conv2, deconv4], concat_dim=3, name='concat4_2')
    deconv4_3 = tl.layers.Conv2dLayer(deconv4_2, act=tf.nn.relu,
                shape=[3, 3, 128, 64], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv4_3')
    deconv4_4 = tl.layers.Conv2dLayer(deconv4_3, act=tf.nn.relu,
                shape=[3, 3, 64, 64], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv4_4')
    # Final 1x1 convolution maps 64 channels to n_out class logits.
    network = tl.layers.Conv2dLayer(deconv4_4,
                act=tf.identity,
                shape=[1, 1, 64, n_out],  # [0]: foreground prob; [1]: background prob
                strides=[1, 1, 1, 1],
                padding='SAME',
                W_init=w_init, b_init=b_init, name='conv4_5')
    # compute the softmax output
    print(" * Output: %s" % network.outputs)
    outputs = tl.act.pixel_wise_softmax(network.outputs)
    return network, outputs
#
With Professional Convolutional Layer APIs (note: this section duplicates the professional-API code above; the original heading said "Simplified" in error).
def u_net_2d_64_1024_deconv_pro(x, n_out=2):
    """2-D U-Net for Image Segmentation (professional Conv2dLayer APIs).

    Parameters
    -----------
    x : tensor or placeholder of input with shape [batch_size, row, col, channel]
    n_out : int
        Number of output channels, default 2 for foreground and background
        (binary segmentation).

    Returns
    --------
    network : TensorLayer layer class with identity (logit) output
    outputs : tensor, the output with pixel-wise softmax

    Notes
    -----
    - Recommend to use Adam with learning rate of 1e-5.
    - NOTE(review): this `def` is byte-for-byte a duplicate of an earlier
      definition with the same name in this file; at import time the later
      definition silently shadows the earlier one -- confirm which is intended.
    - Layer names such as 'devcon1_1' are kept verbatim (typos included) so
      variable scopes stay compatible with existing checkpoints.
    """
    # Use the public shape accessor rather than the private `x._shape` attribute.
    # NOTE(review): assumes the batch dimension is statically known -- a
    # placeholder with shape[0] == None would fail here; confirm with callers.
    batch_size = int(x.get_shape()[0])
    nx = int(x.get_shape()[1])
    ny = int(x.get_shape()[2])
    nz = int(x.get_shape()[3])
    print(" * Input: size of image: %d %d %d" % (nx, ny, nz))
    ## define initializer
    w_init = tf.truncated_normal_initializer(stddev=0.01)
    b_init = tf.constant_initializer(value=0.0)
    ## u-net model
    # convolution (contracting path)
    net_in = tl.layers.InputLayer(x, name='input')
    conv1 = tl.layers.Conv2dLayer(net_in, act=tf.nn.relu,
                shape=[3, 3, nz, 64], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv1')
    conv2 = tl.layers.Conv2dLayer(conv1, act=tf.nn.relu,
                shape=[3, 3, 64, 64], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv2')
    pool1 = tl.layers.PoolLayer(conv2, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='SAME',
                pool=tf.nn.max_pool, name='pool1')
    conv3 = tl.layers.Conv2dLayer(pool1, act=tf.nn.relu,
                shape=[3, 3, 64, 128], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv3')
    conv4 = tl.layers.Conv2dLayer(conv3, act=tf.nn.relu,
                shape=[3, 3, 128, 128], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv4')
    pool2 = tl.layers.PoolLayer(conv4, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='SAME',
                pool=tf.nn.max_pool, name='pool2')
    conv5 = tl.layers.Conv2dLayer(pool2, act=tf.nn.relu,
                shape=[3, 3, 128, 256], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv5')
    conv6 = tl.layers.Conv2dLayer(conv5, act=tf.nn.relu,
                shape=[3, 3, 256, 256], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv6')
    pool3 = tl.layers.PoolLayer(conv6, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='SAME',
                pool=tf.nn.max_pool, name='pool3')
    conv7 = tl.layers.Conv2dLayer(pool3, act=tf.nn.relu,
                shape=[3, 3, 256, 512], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv7')
    conv8 = tl.layers.Conv2dLayer(conv7, act=tf.nn.relu,
                shape=[3, 3, 512, 512], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv8')
    pool4 = tl.layers.PoolLayer(conv8, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='SAME',
                pool=tf.nn.max_pool, name='pool4')
    conv9 = tl.layers.Conv2dLayer(pool4, act=tf.nn.relu,
                shape=[3, 3, 512, 1024], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv9')
    conv10 = tl.layers.Conv2dLayer(conv9, act=tf.nn.relu,
                shape=[3, 3, 1024, 1024], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv10')
    print(" * After conv: %s" % conv10.outputs)  # (batch_size, 32, 32, 1024)
    # deconvolution (expanding path); output_shape requires integers, so use
    # floor division -- plain `/` yields floats under Python 3.
    deconv1 = tl.layers.DeConv2dLayer(conv10, act=tf.identity,
                shape=[3, 3, 512, 1024], strides=[1, 2, 2, 1],
                output_shape=[batch_size, nx // 8, ny // 8, 512],
                padding='SAME', W_init=w_init, b_init=b_init, name='devcon1_1')
    deconv1_2 = tl.layers.ConcatLayer([conv8, deconv1], concat_dim=3, name='concat1_2')
    deconv1_3 = tl.layers.Conv2dLayer(deconv1_2, act=tf.nn.relu,
                shape=[3, 3, 1024, 512], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv1_3')
    deconv1_4 = tl.layers.Conv2dLayer(deconv1_3, act=tf.nn.relu,
                shape=[3, 3, 512, 512], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv1_4')
    deconv2 = tl.layers.DeConv2dLayer(deconv1_4, act=tf.identity,
                shape=[3, 3, 256, 512], strides=[1, 2, 2, 1],
                output_shape=[batch_size, nx // 4, ny // 4, 256],
                padding='SAME', W_init=w_init, b_init=b_init, name='devcon2_1')
    deconv2_2 = tl.layers.ConcatLayer([conv6, deconv2], concat_dim=3, name='concat2_2')
    deconv2_3 = tl.layers.Conv2dLayer(deconv2_2, act=tf.nn.relu,
                shape=[3, 3, 512, 256], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv2_3')
    deconv2_4 = tl.layers.Conv2dLayer(deconv2_3, act=tf.nn.relu,
                shape=[3, 3, 256, 256], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv2_4')
    deconv3 = tl.layers.DeConv2dLayer(deconv2_4, act=tf.identity,
                shape=[3, 3, 128, 256], strides=[1, 2, 2, 1],
                output_shape=[batch_size, nx // 2, ny // 2, 128],
                padding='SAME', W_init=w_init, b_init=b_init, name='devcon3_1')
    deconv3_2 = tl.layers.ConcatLayer([conv4, deconv3], concat_dim=3, name='concat3_2')
    # FIX(review): the original used act=tf.identity here (with the relu
    # commented out) while every sibling decoder conv uses relu -- restored
    # relu for consistency with conv1_3/2_3/4_3.
    deconv3_3 = tl.layers.Conv2dLayer(deconv3_2, act=tf.nn.relu,
                shape=[3, 3, 256, 128], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv3_3')
    deconv3_4 = tl.layers.Conv2dLayer(deconv3_3, act=tf.nn.relu,
                shape=[3, 3, 128, 128], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv3_4')
    deconv4 = tl.layers.DeConv2dLayer(deconv3_4, act=tf.identity,
                shape=[3, 3, 64, 128], strides=[1, 2, 2, 1],
                output_shape=[batch_size, nx, ny, 64],
                padding='SAME', W_init=w_init, b_init=b_init, name='devconv4_1')
    deconv4_2 = tl.layers.ConcatLayer([conv2, deconv4], concat_dim=3, name='concat4_2')
    deconv4_3 = tl.layers.Conv2dLayer(deconv4_2, act=tf.nn.relu,
                shape=[3, 3, 128, 64], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv4_3')
    deconv4_4 = tl.layers.Conv2dLayer(deconv4_3, act=tf.nn.relu,
                shape=[3, 3, 64, 64], strides=[1, 1, 1, 1], padding='SAME',
                W_init=w_init, b_init=b_init, name='conv4_4')
    # Final 1x1 convolution maps 64 channels to n_out class logits.
    network = tl.layers.Conv2dLayer(deconv4_4,
                act=tf.identity,
                shape=[1, 1, 64, n_out],  # [0]: foreground prob; [1]: background prob
                strides=[1, 1, 1, 1],
                padding='SAME',
                W_init=w_init, b_init=b_init, name='conv4_5')
    # compute the softmax output
    print(" * Output: %s" % network.outputs)
    outputs = tl.act.pixel_wise_softmax(network.outputs)
    return network, outputs