U-Net网络结构源码(slim实现):输出结果为单通道,激活函数采用Leaky ReLU(负半轴斜率为0.2)。注意:下方代码最后一层(g_conv10)实际输出2个通道,对应softmax二分类;若采用L2/L1损失输出灰度图,应将该层改为输出1个通道。
def lrelu(x):
    """Leaky ReLU with a fixed negative slope of 0.2.

    Computes max(0.2 * x, x) element-wise, i.e. x for x >= 0 and 0.2 * x
    for x < 0.

    Args:
        x: input tensor.

    Returns:
        Tensor of the same shape as ``x``.
    """
    # Keep the scaled term as the FIRST argument of tf.maximum: at x == 0 both
    # arguments tie and TF routes the gradient to the first one.
    leaked = x * 0.2
    return tf.maximum(leaked, x)


# Default activation handed to the network's slim.conv2d layers.
activation_fn = lrelu
def upsample_and_concat(x1, x2, output_channels, in_channels):
    """Upsample ``x1`` by 2 with a learned transposed conv and append ``x2``.

    Args:
        x1: tensor to upsample; presumably NHWC with ``in_channels`` channels
            (implied by the filter shape and axis-3 concat — TODO confirm).
        x2: skip-connection tensor. Its dynamic shape is used as the
            transposed-convolution output shape, so it must have
            ``output_channels`` channels; the ``set_shape`` below asserts a
            total of ``output_channels * 2`` after concatenation.
        output_channels: channels produced by the transposed convolution.
        in_channels: channels of ``x1``.

    Returns:
        ``[upsampled x1, x2]`` concatenated along axis 3, with static channel
        dimension ``output_channels * 2``.
    """
    stride = 2
    kernel_shape = [stride, stride, output_channels, in_channels]
    # NOTE(review): a bare tf.Variable is not passed through the weights
    # regularizer used by the slim convolutions and is not shared under
    # variable_scope reuse — confirm this is intended before building the
    # graph more than once.
    kernel = tf.Variable(tf.truncated_normal(kernel_shape, stddev=0.02))
    upsampled = tf.nn.conv2d_transpose(
        x1,
        kernel,
        tf.shape(x2),
        strides=[1, stride, stride, 1],
    )
    merged = tf.concat([upsampled, x2], 3)
    # Pin the static channel count; following slim.conv2d layers need it to
    # size their weights, and it may not be inferred from the dynamic shape.
    merged.set_shape([None, None, None, output_channels * 2])
    return merged
def _double_conv(x, num_outputs, stage, reg):
    """Apply two 3x3 slim convs scoped 'g_conv<stage>_1' and 'g_conv<stage>_2'."""
    for index in (1, 2):
        x = slim.conv2d(x, num_outputs, [3, 3], rate=1,
                        activation_fn=activation_fn,
                        scope='g_conv{}_{}'.format(stage, index),
                        weights_regularizer=reg)
    return x


def UNet(inputs, reg, out_channels=2):  # Unet
    """Build a 4-level U-Net with TF-Slim and return the final conv output.

    Encoder: four stages of two 3x3 convs (32, 64, 128, 256 filters), each
    followed by 2x2 SAME max-pooling, then a 512-filter bottleneck (stage 5).
    Decoder: stages 6-9 upsample with ``upsample_and_concat``, concatenate
    the matching encoder output, and apply two 3x3 convs (256, 128, 64, 32
    filters). Hidden convs use ``activation_fn``; the final 1x1 conv
    (variable scope 'output', layer scope 'g_conv10') has no activation.

    Args:
        inputs: input image batch (presumably NHWC — confirm with caller).
        reg: weights regularizer passed to every slim.conv2d layer.
        out_channels: channels of the final 1x1 conv. Defaults to 2, the
            previously hard-coded value used for the softmax two-class
            output; use 1 for an L1/L2 grayscale output.

    Returns:
        Output tensor of the final 1x1 convolution (no activation applied).

    Side effects:
        Prints the static shape of the last decoder feature map.
    """
    # Scope names ('g_conv{stage}_{1,2}') and layer creation order match the
    # original unrolled code, so all variable names stay unchanged.

    # Encoder: keep each stage's output (before pooling) for skip connections.
    skips = []
    net = inputs
    for stage, width in enumerate((32, 64, 128, 256), 1):
        net = _double_conv(net, width, stage, reg)
        skips.append(net)
        net = slim.max_pool2d(net, [2, 2], padding='SAME')

    # Bottleneck.
    net = _double_conv(net, 512, 5, reg)

    # Decoder: consume skips deepest-first (conv4, conv3, conv2, conv1).
    for stage, width in zip((6, 7, 8, 9), (256, 128, 64, 32)):
        # The incoming tensor has twice as many channels as this stage emits.
        net = upsample_and_concat(net, skips.pop(), width, width * 2)
        net = _double_conv(net, width, stage, reg)

    print("conv9.shape:{}".format(net.get_shape()))
    with tf.variable_scope(name_or_scope="output"):
        out = slim.conv2d(net, out_channels, [1, 1], rate=1, activation_fn=None,
                          scope='g_conv10', weights_regularizer=reg)
    return out
源码地址:https://gitee.com/MengNiMeia/myunet
损失函数可采用softmax分类损失或L2/L1损失:采用分类损失时输出为二值图,采用L2/L1损失时输出为灰度图。