UNet的tensorflow2实现

  本篇博客主要使用tensorflow2实现UNet。
参考博客
在这里插入图片描述
  从图中可以看到,UNet主要有下采样和上采样部分;在向上采样的过程中,会用到下采样过程的特征图。
图中 28 4 2 284^2 2842之类的表示图面积。

以下为UNet实现(tensorflow)

import tensorflow as tf
import cv2
from tensorflow import keras 
import numpy as np
from tensorflow.keras.layers import Cropping2D, Concatenate, BatchNormalization, Activation, Softmax, Dropout

input_h = 572
input_w = 572
down_feature_list = []
#save down feature
filter_list = [2**i for i in range(6,10)]

def down_sampling(inputs, filters):
    x = keras.layers.Conv2D(filters, [3,3])(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = keras.layers.Conv2D(filters, [3,3])(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    down_feature_list.append(x)
    return keras.layers.MaxPool2D()(x)

def up_sampling(inputs, down_data, filters):
    now_data = keras.layers.Conv2DTranspose(filters, 2, strides=2, padding='valid')(inputs)
    a_h, a_w = now_data.shape[1:3]
    b_h, b_w = down_data.shape[1:3]
    h_delta = b_h - a_h
    w_delta = b_w - a_w
    cropping = ((h_delta//2, h_delta//2), (w_delta//2, w_delta//2))
    crop_data = Cropping2D(cropping)(down_data)
    concat_data = Concatenate()([crop_data, now_data])

    out_data = keras.layers.Conv2D(filters, [3,3])(concat_data)
    out_data = keras.layers.Conv2D(filters, [3,3])(out_data)
    return out_data
    
def Unet():
    inputs = keras.Input(shape=(input_h, input_w, 3), name="input")
    layer = inputs
    for filters in filter_list:
        layer = down_sampling(layer, filters)

    for filter_num in [1024, 512]:
        layer = keras.layers.Conv2D(filter_num, [3,3])(layer)
        layer = BatchNormalization()(layer)
        layer = Activation('relu')(layer)

    for filters in filter_list[::-1]:
        down_feature = down_feature_list.pop()
        layer = up_sampling(layer, down_feature, filters)
        
    layer = keras.layers.Conv2D(2, 1, padding='valid')(layer)
    probabilities  = Softmax()(layer)
    model = tf.keras.models.Model(inputs, probabilities)
    return model

unet = Unet()
  • 2
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值