【医疗图像分割】UNETR++论文笔记及代码跑通实践

在医疗图像分割任务中,transformer模型获得了巨大的成功,UNETR提出了efficient paired attention (EPA) 模块,利用了空间和通道注意力来有效地学习通道和空间的特征,该模型在Synapse,BTCV,ACDC,BRaTs数据集上都获得了很好地效果。

论文:UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation

代码:https://github.com/Amshaker/unetr_plus_plus

tips:博主有低价AutoDL RTX4090资源,有需要可以联系。

一、论文笔记

首先看一下模型架构,整体还是UNet结构,在其中引入了提出的EPA模块。

该论文的核心就是EPA模块,EPA的提出主要是解决2个问题:

1、计算更有效率:传统的self-attention计算成本很高,对于3D的医疗图像来说更高,EPA将self-attention的K和V投影到低纬度再计算,降低了计算复杂度;

2、增强了空间和通道特征表示能力:transformer本身就是一种空间注意力机制,但是它忽略了通道特征,EPA将空间和通道特征融合在了一起。

再仔细看一下EPA的结构图,上方蓝底部分式空间注意力,下方绿底部分式通道注意力。再空间注意部分,为了降低self-attention计算量,将HWDXC的K和V降维到pXC维度。

代码如下(类中的self.EF用于降低K和V的维度,空间注意力和通道注意力的K和Q是共享的):


class EPA(nn.Module):
    """
        Efficient Paired Attention Block, based on: "Shaker et al.,
        UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
        """
    def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False,
                 channel_attn_drop=0.1, spatial_attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))

        # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel)
        self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)

        # E and F are projection matrices with shared weights used in spatial attention module to project
        # keys and values from HWD-dimension to P-dimension
        self.EF = nn.Parameter(init_(torch.zeros(input_size, proj_size)))

        self.attn_drop = nn.Dropout(channel_attn_drop)
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)

    def forward(self, x):
        B, N, C = x.shape

        qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads)
        qkvv = qkvv.permute(2, 0, 3, 1, 4)
        q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3]

        q_shared = q_shared.transpose(-2, -1)
        k_shared = k_shared.transpose(-2, -1)
        v_CA = v_CA.transpose(-2, -1)
        v_SA = v_SA.transpose(-2, -1)

        proj_e_f = lambda args: torch.einsum('bhdn,nk->bhdk', *args)
        k_shared_projected, v_SA_projected = map(proj_e_f, zip((k_shared, v_SA), (self.EF, self.EF)))

        q_shared = torch.nn.functional.normalize(q_shared, dim=-1)
        k_shared = torch.nn.functional.normalize(k_shared, dim=-1)

        attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature
        attn_CA = attn_CA.softmax(dim=-1)
        attn_CA = self.attn_drop(attn_CA)
        x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C)

        attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2
        attn_SA = attn_SA.softmax(dim=-1)
        attn_SA = self.attn_drop_2(attn_SA)
        x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C)

        return x_CA + x_SA

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'temperature', 'temperature2'}

二、代码实践

官方给出了Synapse,BTCV,ACDC,BRaTs数据集的跑通实例,我这里只跑一个BRaTs数据集,其他的是一样的步骤。

1、安装环境

使用conda安装环境:

conda create --name unetr_pp python=3.10
conda activate unetr_pp

安装torch:

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia

安装依赖:

pip install -r requirements.txt

2、准备数据集

官方给出了处理好的数据集地址,直接下载即可:

数据集链接
SynapseOneDrive
ACDCOneDrive
Decathon-LungOneDrive
BRaTsOneDrive

本文下载好了BraTs数据集作为实例,将其放入以下目录:

3、训练

因为只是跑通一下,把unetr_plus_plus/unetr_pp/training/network_training/unetr_pp_trainer_tumor.py中的epoch改成10:

训练就非常简单了,进入训练集脚本目录并运行脚本:

cd training_scripts
bash run_training_tumor.sh

训练起来了:

4、评估

首先将自己训练的权重放到指定位置(原来output_tumor的unetr_pp文件夹放到unetr_plus_plus\unetr_pp\evaluation\unetr_pp_tumor_checkpoint里面去):

修改代码unetr_plus_plus/unetr_pp/inference/predict.py,共有两处:

进入评估脚本目录并运行脚本:

cd evaluation_scripts

修改run_evaluation_tumor.sh脚本,相关路径替换为自己的路径(自带的脚本我没成功,大家可以自行尝试):

#!/bin/sh

DATASET_PATH=../DATASET_Tumor

export PYTHONPATH=.././
export RESULTS_FOLDER=../unetr_pp/evaluation/unetr_pp_tumor_checkpoint
export unetr_pp_preprocessed="$DATASET_PATH"/unetr_pp_raw/unetr_pp_raw_data/Task03_tumor
export unetr_pp_raw_data_base="$DATASET_PATH"/unetr_pp_raw

# Only for Tumor, it is recommended to train unetr_plus_plus first, and then use the provided checkpoint to evaluate. It might raise issues regarding the pickle files if you evaluated without training

python /deeplearning/medicalseg/unetr_plus_plus/unetr_pp/inference/predict_simple.py -i ../DATASET_Tumor/unetr_pp_raw/unetr_pp_raw_data/Task003_tumor/imagesTs -o ../unetr_pp/evaluation/unetr_pp_tumor_checkpoint/inferTs -m 3d_fullres  -t 3 -f 0 -chk model_final_checkpoint -tr unetr_pp_trainer_tumor


python /deeplearning/medicalseg/unetr_plus_plus/unetr_pp/inference_tumor.py 0

修改unetr_plus_plus/unetr_pp/inference_tumor.py的数据集路径,可以根据自己的情况改:

运行脚本:

bash run_evaluation_tumor.sh

在推理结果的目录unetr_plus_plus/unetr_pp/evaluation/unetr_pp_tumor_checkpoint/下多了一个

dice_five.txt文件,里面有相关精度,如下(因为就训练了10个epoch,效果不行):

本文到此结束。

评论 33
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

justld

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值