unet网络结构说明及keras实现详解

原文链接:https://blog.csdn.net/weixin_38193906/article/details/83787569?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task

论文地址:

下载链接

背景介绍

unet网络常见于图像分割任务,本文从其网络结构出发,详细解释unet网络结构的实现过程。

网络结构

unet网络结构

网络结构说明

unet网络可以简单看为先下采样,经过不同程度的卷积,学习了深层次的特征,在经过上采样回复为原图大小,上采样用反卷积实现。最后输出类别数量的特征图,如分割是两类(是或不是),典型unet也是输出两张图,最后要说明一下,原网络到此就结束了,其实在最后还要使用激活函数softmax将这两个类别转换为概率图,针对某个像素点,如输出是[0.1,0.9],则判定这个像素点是第二类的概率更大。
网络结构可以看成3个部分:

    1. 下采样:网络的红色箭头部分,池化实现
    1. 上采样:网络的绿色箭头部分,反卷积实现
    1. 最后层的softmax:在网络结构中,最后输出两张fiture maps后,其实在最后还要做一次softmax,将其转换为概率图。

keras实现详解

from __future__ import division, print_function
from keras.layers import Input, Conv2D, Conv2DTranspose
from keras.layers import MaxPooling2D, Cropping2D, Concatenate
from keras.layers import Lambda, Activation, BatchNormalization, Dropout
from keras.models import Model
from keras import backend as K
def downsampling_block(input_tensor, filters, padding='valid',  #下采样部分
                       batchnorm=False, dropout=0.0):
    _, height, width, _ = K.int_shape(input_tensor)
    assert height % 2 == 0
    assert width % 2 == 0
    x = Conv2D(filters, kernel_size=(3,3), padding=padding)(input_tensor)
    x = BatchNormalization()(x) if batchnorm else x
    x = Activation('relu')(x)
    x = Dropout(dropout)(x) if dropout > 0 else x
    x = Conv2D(filters, kernel_size=(3,3), padding=padding)(x)
    x = BatchNormalization()(x) if batchnorm else x
    x = Activation('relu')(x)
    x = Dropout(dropout)(x) if dropout > 0 else x
    return MaxPooling2D(pool_size=(2,2))(x), x   #返回的是池化后的值和dropout后的值,这里dropout后的值用于上采样特征级联
def upsampling_block(input_tensor, skip_tensor, filters, padding='valid',
                     batchnorm=False, dropout=0.0):    #下采样部分
    x = Conv2DTranspose(filters, kernel_size=(2,2), strides=(2,2))(input_tensor)
    _, x_height, x_width, _ = K.int_shape(x)
    _, s_height, s_width, _ = K.int_shape(skip_tensor)
    h_crop = s_height - x_height
    w_crop = s_width - x_width
    assert h_crop >= 0
    assert w_crop >= 0
    if h_crop == 0 and w_crop == 0:
        y = skip_tensor
    else:                       #使级联时像素大小一致
        cropping = ((h_crop//2, h_crop - h_crop//2), (w_crop//2, w_crop - w_crop//2))
        y = Cropping2D(cropping=cropping)(skip_tensor)
    x = Concatenate()([x, y])         #特征级联
    x = Conv2D(filters, kernel_size=(3,3), padding=padding)(x)
    x = BatchNormalization()(x) if batchnorm else x
    x = Activation('relu')(x)
    x = Dropout(dropout)(x) if dropout > 0 else x
    x = Conv2D(filters, kernel_size=(3,3), padding=padding)(x)
    x = BatchNormalization()(x) if batchnorm else x
    x = Activation('relu')(x)
    x = Dropout(dropout)(x) if dropout > 0 else x
    return x                   #返回dropout后的值
def unet(height, width, channels, classes, features=64, depth=4,
         temperature=1.0, padding='valid', batchnorm=False, dropout=0.0):  #使用4个深度长的网络就是官网的典型网络
x = Input(shape=(height, width, channels))
    inputs = x
    skips = []                   #用于存放下采样中,每个深度后,dropout后的值,以供之后级联使用
    for i in range(depth):
        x, x0 = downsampling_block(x, features, padding,
                                   batchnorm, dropout)
        skips.append(x0)
        features *= 2            #下采样过程中,每个深度往下,特征翻倍,即每次使用翻倍数目的滤波器
    x = Conv2D(filters=features, kernel_size=(3,3), padding=padding)(x)
    x = BatchNormalization()(x) if batchnorm else x
    x = Activation('relu')(x)
    x = Dropout(dropout)(x) if dropout > 0 else x
    x = Conv2D(filters=features, kernel_size=(3,3), padding=padding)(x)
    x = BatchNormalization()(x) if batchnorm else x
    x = Activation('relu')(x)
    x = Dropout(dropout)(x) if dropout > 0 else x
    for i in reversed(range(depth)):    #下采样过程中,深度从深到浅
        features //= 2                  #每个深度往上。特征减少一倍
        x = upsampling_block(x, skips[i], features, padding,
                             batchnorm, dropout)
    x = Conv2D(filters=classes, kernel_size=(1,1))(x)
    logits = Lambda(lambda z: z/temperature)(x)      #简单的对x做一个变换
    probabilities = Activation('softmax')(logits)    #对输出的两类做softmax,转换为概率。形式如【0.1,0.9],则预测为第二类的概率更大。
    return Model(inputs=inputs, outputs=probabilities)
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值