基于加权的Res-UNet注意力机制分割视网膜血管分割挑战

针对视网膜血管分割的挑战,文章提出了加权Res-UNet模型,该模型基于UNet并引入加权注意力机制,改善小血管缺失、视盘区分割差和结构关系维持的问题。通过预处理增强图像对比度,使用CLAHE、resize和数据增强。模型结构包括加权注意、跳跃连接和二值交叉熵损失函数。在测试阶段,采用patch重组获取完整分割结果。
摘要由CSDN通过智能技术生成

•小血管缺失:位于树枝末端的小血管有时甚至人眼也难以分辨。
•视盘区分割差:视盘区往往较亮,对比度较低,使得视网膜血管难以分割。
•无法维持结构关系:视网膜血管有类似于树木的分叉结构。但当血管太薄而无法探测时,这种结构很难维持。
•照明:光照不足或过度暴露,包括相机光源引起的光反射,会降低图像对比度,导致视网膜血管边界不清晰。

为了解决这些挑战,本文提出:加权的Res-UNet。模型建立在原始的UNet模型基础上,并增加了一个加权注意机制。使得模型学习更多的鉴别血管和非血管像素的特征,更好的维护视网膜血管树结构。

架构

 

方法

A. 眼底图像预处理

原始眼底图像对比度低,经过预处理后的图像会使CNN有更好的性能。

使用对比度限制自适应直方图均衡化(CLAHE)操作作为预处理步骤,来提高图像对比度。

将每个原始眼底图片resize为512*512,并变为灰度图,再使用CLAHE方法标准化灰度图。

由于训练图像的数量非常有限,我们在每张训练图像中随机抽取了500个重叠64*64的patch。此外,我们在每个提取的patch中都采用了广泛使用的数据扩充操作,如水平翻转、宽度移位范围等。这样,训练图像的数量就增加了500多倍。

B.Res-UNet的体系结构

1)加权注意机制:DRIVE和STARE中的眼底图像具有圆形感兴趣区域(ROI)和深色背景。然后使用圆形模板感兴趣区域蒙版M作为加权注意,即图2所示的黄色箭头。

注意机制是通过将模型的最后一层的特征图与注意掩模相乘来实现的。
利用这种加权注意机制,我们的模型将只关注目标ROI区域,抛弃无关的噪声背景。对于DRIVE数据集,我们直接使用提供的眼底区域掩码作为加权注意掩码。而对于STARE数据集,我们通过一个简单的处理步骤计算出了注意掩码,将眼底图像转换成灰度,应用高斯滤波,然后在值40处进行二值化阈值分割眼底区域。

2)跳跃连接方案:如U-Net所示,增加跳跃连接可以增加深度,提高深度CNNs的准确性。受此启发,我们还将skip连接添加到模型中,如图2中的纯灰色箭头所示。对于每个卷积块,跳跃链接的公式为:y = F(x, {wi}) + H(x),  F包含两个卷积运算和一个max-pooling或一个up-sampling运算,H要么是相同的映射,要么是卷积运算,使输入的特征维数与F相同。

3)损失函数:为了训练我们提出的模型,我们选择binary cross entropy作为分割损失函数

C. patch重组

在测试阶段,我们没有像在训练阶段那样做随机的重叠patch裁剪,我们只是把512×512的输入平铺到8×8个大小为64×64的不重叠patch上。这样,在得到每个分割块的预测后,我们可以根据每个分割块的位置重新分组得到整个眼底图像的分割结果。

 

图2 加权Res-UNet的总体结构

这里我提供一种增加注意力机制的unet模型


from keras.models import *
from keras.layers import *
from keras.optimizers import Adam, SGD
from keras import backend as K


class Unet:
    def __init__(self, pretrained_weights=None, input_size=(160, 160, 1)):
        self.model = self.unet(pretrained_weights, input_size)


    def summary(self):
        self.model.summary()


    def channel_attention(self, input_feature, ratio=8):
        channel = input_feature._keras_shape[-1]

        shared_layer_one = Dense(channel // ratio,
                                 activation='relu',
                                 kernel_initializer='he_normal',
                                 use_bias=True,
                                 bias_initializer='zeros')
        shared_layer_two = Dense(channel,
                                 kernel_initializer='he_normal',
                                 use_bias=True,
                                 bias_initializer='zeros')

        avg_pool = GlobalAveragePooling2D()(input_feature)
        avg_pool = Reshape((1, 1, channel))(avg_pool)
        avg_pool = shared_layer_one(avg_pool)
        avg_pool = shared_layer_two(avg_pool)

        max_pool = GlobalMaxPooling2D()(input_feature)
        max_pool = Reshape((1, 1, channel))(max_pool)
        max_pool = shared_layer_one(max_pool)
        max_pool = shared_layer_two(max_pool)

        cbam_feature = Add()([avg_pool, max_pool])
        cbam_feature = Activation('sigmoid')(cbam_feature)
        return multiply([input_feature, cbam_feature])


    def spatial_attention(self, input_feature, kernel_size=7):
        avg_pool = Lambda(lambda x: K.mean(x, axis=-1, keepdims=True))(input_feature)
        max_pool = Lambda(lambda x: K.max(x, axis=-1, keepdims=True))(input_feature)
        concat = Concatenate(axis=-1)([avg_pool, max_pool])
        cbam_feature = Conv2D(filters=1,
                              kernel_size=kernel_size,
                              strides=1,
                              padding='same',
                              activation='sigmoid',
                              kernel_initializer='he_normal',
                              use_bias=False)(concat)
        return multiply([input_feature, cbam_feature])


    def cbam_block(self, cbam_feature, ratio=2):
        # https://github.com/kobiso/CBAM-keras/blob/master/models/attention_module.py
        cbam_feature = self.channel_attention(cbam_feature, ratio)
        cbam_feature = self.spatial_attention(cbam_feature)
        return cbam_feature


    def unet(self, pretrained_weights=None, input_size=(128, 128, 1)):
        def dice_coef(y_true, y_pred, smooth=1):
            intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
            return (2. * intersection + smooth) / (
                    K.sum(K.square(y_true), -1) + K.sum(K.square(y_pred), -1) + smooth)

        def dice_coef_loss(y_true, y_pred):
            return 1 - dice_coef(y_true, y_pred)

        inputs = Input(input_size)
        
        d1 = Conv2D(32, 3, activation='relu', padding='same')(inputs)
        d1 = BatchNormalization()(d1)
        d1 = Conv2D(32, 3, activation='relu', padding='same')(d1)
        d1 = BatchNormalization()(d1)
        d1 = SpatialDropout2D(0.1)(d1)
        d1 = self.spatial_attention(d1)

        d2 = MaxPooling2D(pool_size=(2, 2))(d1)
        d2 = Conv2D(64, 3, activation='relu', padding='same')(d2)
        d2 = BatchNormalization()(d2)
        d2 = Conv2D(64, 3, activation='relu', padding='same')(d2)
        d2 = BatchNormalization()(d2)
        d2 = SpatialDropout2D(0.1)(d2)
        d2 = self.spatial_attention(d2)

        d3 = MaxPooling2D(pool_size=(2, 2))(d2)
        d3 = Conv2D(128, 3, activation='relu', padding='same')(d3)
        d3 = BatchNormalization()(d3)
        d3 = Conv2D(128, 3, activation='relu', padding='same')(d3)
        d3 = BatchNormalization()(d3)
        d3 = SpatialDropout2D(0.25)(d3)
        d3 = self.spatial_attention(d3)

        d4 = MaxPooling2D(pool_size=(2, 2))(d3)
        d4 = Conv2D(256, 3, activation='relu', padding='same')(d4)
        d4 = BatchNormalization()(d4)
        d4 = Conv2D(256, 3, activation='relu', padding='same')(d4)
        d4 = BatchNormalization()(d4)
        d4 = SpatialDropout2D(0.4)(d4)
        d4 = self.spatial_attention(d4)

        u3 = UpSampling2D(size=(2, 2))(d4)
        u3 = Conv2D(128, 2, activation='relu', padding='same')(u3)
        u3 = BatchNormalization()(u3)
        u3 = concatenate([d3, u3], axis=-1)
        u3 = Conv2D(128, 3, activation='relu', padding='same')(u3)
        u3 = BatchNormalization()(u3)
        u3 = Conv2D(128, 3, activation='relu', padding='same')(u3)
        u3 = BatchNormalization()(u3)
        u3 = self.spatial_attention(u3)

        u2 = UpSampling2D(size=(2, 2))(u3)
        u2 = Conv2D(64, 2, activation='relu', padding='same')(u2)
        u2 = BatchNormalization()(u2)
        u2 = concatenate([d2, u2], axis=-1)
        u2 = Conv2D(64, 3, activation='relu', padding='same')(u2)
        u2 = BatchNormalization()(u2)
        u2 = Conv2D(64, 3, activation='relu', padding='same')(u2)
        u2 = BatchNormalization()(u2)
        u2 = self.spatial_attention(u2)

        u1 = UpSampling2D(size=(2, 2))(u2)
        u1 = Conv2D(32, 2, activation='relu', padding='same')(u1)
        u1 = BatchNormalization()(u1)
        u1 = concatenate([d1, u1], axis=-1)
        u1 = Conv2D(32, 3, activation='relu', padding='same')(u1)
        u1 = BatchNormalization()(u1)
        u1 = Conv2D(32, 3, activation='relu', padding='same')(u1)
        u1 = BatchNormalization()(u1)
        u1 = self.spatial_attention(u1)

        out = Conv2D(3, 3, activation='relu', padding='same')(u1)
        out = BatchNormalization()(out)
        out = self.spatial_attention(out)

        out = Conv2D(1, 1, activation='sigmoid')(out)

        model = Model(inputs=inputs, outputs=out)

        # optimizer = Adam(lr=0.001)
        optimizer = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
        model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["acc", dice_coef])

        if pretrained_weights is not None:
            model.load_weights(pretrained_weights)

        return model


    def get_model(self):
        return self.model

结果

 

image.png

参考链接:
基于U-Net+残差网络的语义分割缺陷检测
Keras 使用Residual-Block 加深U-net网络的深度
U-net与ResNet结合
基于Resnet+Unet的图像分割模型(by Pytorch)
U-Net 和 ResNet:长短跳跃连接的重要性(生物医学图像分割)
作者:zelda2333
链接:https://www.jianshu.com/p/5f92303f6a9c
 

评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值