Unet模型
今天给大家简单介绍一下Unet网络。
网络结构
Unet论文
Unet是2015年提出一种语义分割模型,主要用于医学领域的图像分割问题,因其网络结构呈现一个U型,故名为U-Net。网络结构如下图所示:
网络结构说明
这是一种对称的结构。首先通过卷积池化进行特征提取,然后经过上采样进行重构。
从这个网络中可以看到,输入图像大小为572x572,输出图像大小却是388x388,输出比输入要小,这主要是满足医学领域分割的需要,提高分割的精度。
其主要采用3x3的卷积核和relu激活函数进行特征提取,使用2x2的最大值池化进行下采样进行压缩。上采样时,使用了邻近插值,没有使用转置卷积。在每个上采样后进行跳跃连接,使用了concat操作,将特征在channel维度拼接在一起,形成更厚的特征,将全局特征和局部特征进行结合,而不是简单的相加。随后使用卷层进行channel维度压缩。跳跃连接时,将大图裁小 ,这样虽然会丢失一些局部的信息,不过没关系,因为它只是对全局的补充,辅助位置矫正。
下面简单总结一下Unet的特点:
- U-Net是完全对称的,且对解码器进行了加卷积加深处理。
- 上采样的时候,使用了邻近插值,没有使用转置卷积。
- 跳跃连接使用了concat操作,将特征在channel维度拼接在一起,而不是简单的相加。
- 全程使用valid进行卷积(包括pooling),这样的话可以保证分割的结果都是基于没有缺失的上下文特征得到的,因此输入输出的图像尺寸不太一样。
网络实现
下面给出了Unet网络的keras实现:
def conv_block(inputs, filters):
conv1 = Conv2D(filters=filters, kernel_size=3, activation='relu', padding='same',
kernel_initializer='he_normal')(inputs)
#bn1 = BatchNormalization()(conv1)
conv2 = Conv2D(filters=filters, kernel_size=3, activation='relu', padding='same',
kernel_initializer='he_normal')(conv1)
#bn2 = BatchNormalization()(conv2)
return conv2 #bn2
def upsampling(inputs, filters):
up = UpSampling2D(size=(2, 2))(inputs)
conv = Conv2D(filters=filters, kernel_size=2, activation='relu', padding='same',
kernel_initializer='he_normal')(up)
return conv
##############################
#Uet #
##############################
def U_net(input_size=(224, 224, 1), n_class=2, filters=(64, 128, 256, 512, 1024), re_shape=False):
inputs = Input(shape=input_size)
conv1 = conv_block(inputs=inputs, filters=filters[0])
pool1 = MaxPool2D(pool_size=(2, 2))(conv1)
conv2 = conv_block(inputs=pool1, filters=filters[1])
pool2 = MaxPool2D(pool_size=(2, 2))(conv2)
conv3 = conv_block(inputs=pool2, filters=filters[2])
pool3 = MaxPool2D(pool_size=(2, 2))(conv3)
conv4 = conv_block(inputs=pool3, filters=filters[3])
drop4 = Dropout(0.5)(conv4)
pool4 = MaxPool2D(pool_size=(2, 2))(drop4)
conv5 = conv_block(inputs=pool4, filters=filters[4])
drop5 = Dropout(0.5)(conv5)
up1 = upsampling(inputs=drop5, filters=filters[3])
concat1 = concatenate([drop4, up1], axis=3)
conv6 = conv_block(inputs=concat1, filters=filters[3])
up2 = upsampling(inputs=conv6, filters=filters[2])
concat2 = concatenate([conv3, up2], axis=3)
conv7 = conv_block(inputs=concat2, filters=filters[2])
up3 = upsampling(inputs=conv7, filters=filters[1])
concat3 = concatenate([conv2, up3], axis=3)
conv8 = conv_block(inputs=concat3, filters=filters[1])
up4 = upsampling(inputs=conv8, filters=filters[0])
concat4 = concatenate([conv1, up4], axis=3)
conv9 = conv_block(inputs=concat4, filters=filters[0])
conv10 = Conv2D(filters=n_class, kernel_size=1, padding='same',
kernel_initializer='he_normal')(conv9)
if re_shape==True:
conv10 = Reshape((input_size[0]*input_size[1], n_class))(conv10)
out =Activation('sigmoid')(conv10)
model = Model(input=inputs, output=out)
model.summary()
model.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy', metrics=['accuracy'])
return model
这里使用keras框架实现Unet网络,padding方式选择same,输入输出图像大小保持一致。可以通过修改input_size,n_class,filters来适应自己的数据。