1 Unet结构
a. Unet结构
input (512,512,3)
Encoder (VGG16 backbone):
  Conv(f=64)  x2 -> f1 (512,512,64)
  MaxPooling(s=2) -> (256,256,64)
  Conv(f=128) x2 -> f2 (256,256,128)
  MaxPooling(s=2) -> (128,128,128)
  Conv(f=256) x3 -> f3 (128,128,256)
  MaxPooling(s=2) -> (64,64,256)
  Conv(f=512) x3 -> f4 (64,64,512)
  MaxPooling(s=2) -> (32,32,512)
  Conv(f=512) x3 -> f5 (32,32,512)
Decoder (upsample, concatenate skip, 2x Conv):
  upsampling(f5) (64,64,512)  -> Concatenate with f4 (64,64,1024) -> Conv(f=512) x2 (64,64,512)
  upsampling (128,128,512)    -> Concatenate with f3 (128,128,768) -> Conv(f=256) x2 (128,128,256)
  upsampling (256,256,256)    -> Concatenate with f2 (256,256,384) -> Conv(f=128) x2 (256,256,128)
  upsampling (512,512,128)    -> Concatenate with f1 (512,512,192) -> Conv(f=64)  x2 (512,512,64)
b. 代码
import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.layers import *
from keras.models import *

from nets.vgg16 import VGG16
def Unet(input_shape=(256, 256, 3), num_classes=21):
    """Build a U-Net segmentation model on a VGG16 encoder.

    Parameters
    ----------
    input_shape : tuple
        (height, width, channels) of the input image.
    num_classes : int
        Number of output classes; the model ends with a per-pixel softmax
        over this many channels.

    Returns
    -------
    keras.models.Model
        Maps an input image to a (H, W, num_classes) probability map.
    """
    inputs = Input(input_shape)
    # Encoder: VGG16 returns five feature maps at strides 1, 2, 4, 8, 16.
    feat1, feat2, feat3, feat4, feat5 = VGG16(inputs)

    channels = [64, 128, 256, 512]

    # Decoder: at each stage, upsample by 2, concatenate the matching
    # encoder skip connection, then refine with two 3x3 convolutions.
    # Stages run from the deepest feature (feat5) back up to feat1,
    # so the channel widths are taken in reverse order: 512, 256, 128, 64.
    x = feat5
    for skip, width in zip([feat4, feat3, feat2, feat1], reversed(channels)):
        x = UpSampling2D(size=(2, 2))(x)
        x = Concatenate(axis=3)([skip, x])
        x = Conv2D(width, 3, activation='relu', padding='same', kernel_initializer='he_normal')(x)
        x = Conv2D(width, 3, activation='relu', padding='same', kernel_initializer='he_normal')(x)

    # 1x1 convolution + softmax produces the per-pixel class distribution.
    x = Conv2D(num_classes, 1, activation="softmax")(x)
    return Model(inputs=inputs, outputs=x)
2 损失
a. 损失函数
# 1.Cross Entropy Loss。
# 2.Dice Loss = 1 - Dice
Dice = 2|X∩Y| / (|X| + |Y|); soft (differentiable) form: Dice = 2·Σ(x·y) / (Σx + Σy)
b. 代码
def dice_loss_with_CE(beta=1, smooth=1e-5):
    """Return a loss function combining cross-entropy and Dice loss.

    Parameters
    ----------
    beta : float
        F-beta weighting between false negatives and false positives;
        beta=1 gives the standard Dice coefficient.
    smooth : float
        Small constant keeping the Dice ratio defined when a class is
        absent from both prediction and target.
    """
    def _dice_loss_with_CE(y_true, y_pred):
        # Clip so K.log never sees exactly 0 or 1.
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        # Cross-entropy term. The last channel of y_true is dropped —
        # presumably an extra flag channel appended to the one-hot labels;
        # TODO(review): confirm against the label encoding.
        CE_loss = - y_true[..., :-1] * K.log(y_pred)
        CE_loss = K.mean(K.sum(CE_loss, axis=-1))
        # Soft true/false positives and false negatives per class,
        # summed over batch and spatial axes.
        tp = K.sum(y_true[..., :-1] * y_pred, axis=[0, 1, 2])
        fp = K.sum(y_pred, axis=[0, 1, 2]) - tp
        fn = K.sum(y_true[..., :-1], axis=[0, 1, 2]) - tp
        # Soft F-beta score per class.
        score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
        # K.mean replaces tf.reduce_mean: identical result, but the loss
        # only needs the Keras backend (`tf` was referenced without being
        # imported in the original listing).
        dice_loss = 1 - K.mean(score)
        return CE_loss + dice_loss
    return _dice_loss_with_CE
def CE():
    """Return a categorical cross-entropy loss closure.

    The returned function expects y_true with one more channel than
    y_pred; the final y_true channel is ignored (presumably an extra
    flag channel appended to the one-hot labels — verify against the
    data pipeline).
    """
    def _CE(y_true, y_pred):
        # Keep predictions strictly inside (0, 1) so log() stays finite.
        pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        # Per-pixel cross-entropy: sum over the class axis, then average.
        per_pixel = K.sum(- y_true[..., :-1] * K.log(pred), axis=-1)
        return K.mean(per_pixel)
    return _CE