论文地址:
背景介绍
unet网络常见于图像分割任务,本文从其网络结构出发,详细解释unet网络结构的实现过程。
网络结构
网络结构说明
unet网络可以简单看为先下采样,经过不同程度的卷积,学习了深层次的特征,在经过上采样回复为原图大小,上采样用反卷积实现。最后输出类别数量的特征图,如分割是两类(是或不是),典型unet也是输出两张图,最后要说明一下,原网络到此就结束了,其实在最后还要使用激活函数softmax将这两个类别转换为概率图,针对某个像素点,如输出是[0.1,0.9],则判定这个像素点是第二类的概率更大。
网络结构可以看成3个部分:
-
- 下采样:网络的红色箭头部分,池化实现
-
- 上采样:网络的绿色箭头部分,反卷积实现
-
- 最后层的softmax:在网络结构中,最后输出两张fiture maps后,其实在最后还要做一次softmax,将其转换为概率图。
keras实现详解
from __future__ import division, print_function
from keras.layers import Input, Conv2D, Conv2DTranspose
from keras.layers import MaxPooling2D, Cropping2D, Concatenate
from keras.layers import Lambda, Activation, BatchNormalization, Dropout
from keras.models import Model
from keras import backend as K
def downsampling_block(input_tensor, filters, padding='valid', #下采样部分
batchnorm=False, dropout=0.0):
_, height, width, _ = K.int_shape(input_tensor)
assert height % 2 == 0
assert width % 2 == 0
x = Conv2D(filters, kernel_size=(3,3), padding=padding)(input_tensor)
x = BatchNormalization()(x) if batchnorm else x
x = Activation('relu')(x)
x = Dropout(dropout)(x) if dropout > 0 else x
x = Conv2D(filters, kernel_size=(3,3), padding=padding)(x)
x = BatchNormalization()(x) if batchnorm else x
x = Activation('relu')(x)
x = Dropout(dropout)(x) if dropout > 0 else x
return MaxPooling2D(pool_size=(2,2))(x), x #返回的是池化后的值和dropout后的值,这里dropout后的值用于上采样特征级联
def upsampling_block(input_tensor, skip_tensor, filters, padding='valid',
batchnorm=False, dropout=0.0): #下采样部分
x = Conv2DTranspose(filters, kernel_size=(2,2), strides=(2,2))(input_tensor)
_, x_height, x_width, _ = K.int_shape(x)
_, s_height, s_width, _ = K.int_shape(skip_tensor)
h_crop = s_height - x_height
w_crop = s_width - x_width
assert h_crop >= 0
assert w_crop >= 0
if h_crop == 0 and w_crop == 0:
y = skip_tensor
else: #使级联时像素大小一致
cropping = ((h_crop//2, h_crop - h_crop//2), (w_crop//2, w_crop - w_crop//2))
y = Cropping2D(cropping=cropping)(skip_tensor)
x = Concatenate()([x, y]) #特征级联
x = Conv2D(filters, kernel_size=(3,3), padding=padding)(x)
x = BatchNormalization()(x) if batchnorm else x
x = Activation('relu')(x)
x = Dropout(dropout)(x) if dropout > 0 else x
x = Conv2D(filters, kernel_size=(3,3), padding=padding)(x)
x = BatchNormalization()(x) if batchnorm else x
x = Activation('relu')(x)
x = Dropout(dropout)(x) if dropout > 0 else x
return x #返回dropout后的值
def unet(height, width, channels, classes, features=64, depth=4,
temperature=1.0, padding='valid', batchnorm=False, dropout=0.0): #使用4个深度长的网络就是官网的典型网络
x = Input(shape=(height, width, channels))
inputs = x
skips = [] #用于存放下采样中,每个深度后,dropout后的值,以供之后级联使用
for i in range(depth):
x, x0 = downsampling_block(x, features, padding,
batchnorm, dropout)
skips.append(x0)
features *= 2 #下采样过程中,每个深度往下,特征翻倍,即每次使用翻倍数目的滤波器
x = Conv2D(filters=features, kernel_size=(3,3), padding=padding)(x)
x = BatchNormalization()(x) if batchnorm else x
x = Activation('relu')(x)
x = Dropout(dropout)(x) if dropout > 0 else x
x = Conv2D(filters=features, kernel_size=(3,3), padding=padding)(x)
x = BatchNormalization()(x) if batchnorm else x
x = Activation('relu')(x)
x = Dropout(dropout)(x) if dropout > 0 else x
for i in reversed(range(depth)): #下采样过程中,深度从深到浅
features //= 2 #每个深度往上。特征减少一倍
x = upsampling_block(x, skips[i], features, padding,
batchnorm, dropout)
x = Conv2D(filters=classes, kernel_size=(1,1))(x)
logits = Lambda(lambda z: z/temperature)(x) #简单的对x做一个变换
probabilities = Activation('softmax')(logits) #对输出的两类做softmax,转换为概率。形式如【0.1,0.9],则预测为第二类的概率更大。
return Model(inputs=inputs, outputs=probabilities)