本篇博客主要使用tensorflow2实现UNet。
参考博客
从图中可以看到,UNet主要有下采样和上采样部分;在向上采样的过程中,会用到下采样过程的特征图。
图中 $284^2$ 之类的标注表示特征图的面积(即 $284 \times 284$)。
以下为UNet实现(tensorflow)
import tensorflow as tf
import cv2
from tensorflow import keras
import numpy as np
from tensorflow.keras.layers import Cropping2D, Concatenate, BatchNormalization, Activation, Softmax, Dropout
# Height and width, in pixels, of the image fed to the network.
input_h = input_w = 572

# Stack of encoder feature maps: down_sampling() pushes one entry per stage
# and Unet() pops them (last in, first out) to use as skip connections.
down_feature_list = []

# Channel counts of the four encoder stages: 64, 128, 256, 512.
filter_list = [1 << exponent for exponent in range(6, 10)]
def down_sampling(inputs, filters):
    """Run one encoder stage and max-pool its output.

    The stage applies Conv2D(3x3) -> BatchNormalization -> ReLU twice. The
    un-pooled result is appended to the module-level ``down_feature_list`` so
    the decoder can later reuse it as a skip connection.

    Args:
        inputs: Keras tensor entering the stage.
        filters: Number of output channels of both convolutions.

    Returns:
        The stage output after ``MaxPool2D`` with Keras' default settings.
    """
    features = inputs
    for _ in range(2):
        features = keras.layers.Conv2D(filters, [3, 3])(features)
        features = BatchNormalization()(features)
        features = Activation('relu')(features)

    # Remember the pre-pooling feature map for the matching decoder stage.
    down_feature_list.append(features)

    pooled = keras.layers.MaxPool2D()(features)
    return pooled
def up_sampling(inputs, down_data, filters):
    """Run one decoder stage: upsample, merge the skip feature, convolve.

    ``inputs`` is upsampled by a stride-2 transposed convolution. The encoder
    feature ``down_data`` is center-cropped to the upsampled size and
    concatenated with it, and the result goes through two
    Conv2D(3x3) -> BatchNormalization -> ReLU rounds (the same pattern
    ``down_sampling`` uses).

    Args:
        inputs: Keras tensor coming from the previous (deeper) stage.
        down_data: Encoder feature map used as the skip connection. Its
            height and width must be statically known and at least as large
            as those of the upsampled tensor.
        filters: Number of output channels of the transposed convolution and
            of both 3x3 convolutions.

    Returns:
        The Keras tensor produced by the second convolution round.

    Raises:
        ValueError: If ``down_data`` is spatially smaller than the upsampled
            tensor, so it cannot be cropped to match.
    """
    now_data = keras.layers.Conv2DTranspose(
        filters, 2, strides=2, padding='valid')(inputs)

    a_h, a_w = now_data.shape[1:3]
    b_h, b_w = down_data.shape[1:3]
    h_delta = b_h - a_h
    w_delta = b_w - a_w
    if h_delta < 0 or w_delta < 0:
        raise ValueError(
            "skip feature (%sx%s) is smaller than the upsampled tensor "
            "(%sx%s) and cannot be cropped to match" % (b_h, b_w, a_h, a_w))

    # Split the surplus between both sides; when it is odd the extra pixel is
    # removed from the bottom/right so the cropped size matches exactly.
    top = h_delta // 2
    left = w_delta // 2
    cropping = ((top, h_delta - top), (left, w_delta - left))
    crop_data = Cropping2D(cropping)(down_data)
    concat_data = Concatenate()([crop_data, now_data])

    # Each convolution needs a non-linearity after it; two bare convolutions
    # in a row would collapse into a single linear map.
    out_data = concat_data
    for _ in range(2):
        out_data = keras.layers.Conv2D(filters, [3, 3])(out_data)
        out_data = BatchNormalization()(out_data)
        out_data = Activation('relu')(out_data)
    return out_data
def Unet(num_classes=2):
    """Build the U-Net segmentation model.

    The network consists of one ``down_sampling`` stage per entry of
    ``filter_list``, a two-convolution bottleneck (1024 then 512 channels),
    one ``up_sampling`` stage per encoder stage in reverse order, and a 1x1
    convolution followed by a softmax.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution,
            i.e. the number of classes the softmax distributes over.
            Defaults to 2, the value that was previously hard-coded.

    Returns:
        A ``tf.keras.models.Model`` mapping an
        ``(input_h, input_w, 3)`` image to per-pixel class probabilities.
    """
    # Start from an empty skip stack so leftovers from an earlier build that
    # failed part-way cannot accumulate.
    down_feature_list.clear()

    inputs = keras.Input(shape=(input_h, input_w, 3), name="input")

    # Encoder: each stage pushes its pre-pooling feature map on the stack.
    layer = inputs
    for filters in filter_list:
        layer = down_sampling(layer, filters)

    # Bottleneck.
    for filter_num in [1024, 512]:
        layer = keras.layers.Conv2D(filter_num, [3, 3])(layer)
        layer = BatchNormalization()(layer)
        layer = Activation('relu')(layer)

    # Decoder: pop skip features so the deepest encoder stage is used first.
    for filters in reversed(filter_list):
        down_feature = down_feature_list.pop()
        layer = up_sampling(layer, down_feature, filters)

    layer = keras.layers.Conv2D(num_classes, 1, padding='valid')(layer)
    probabilities = Softmax()(layer)
    model = tf.keras.models.Model(inputs, probabilities)
    return model
# Build the model once at module load, using the module-level input size.
unet = Unet()