复现Hermes

文章:Training Like a Medical Resident: Universal Medical Image Segmentation via Context Prior Learning
2024 CVPR

1.环境安装

先新建一个环境吧

conda create -n hermes python=3.9

1.1 pip install -r requirements.txt


在这里插入图片描述
解决方法:去掉 fonttools

安装apex

根据官方指引安装
APEX介绍以及安装:英伟达(NVIDIA)训练深度学习模型神器APEX使用指南
我用的官方网站的安装

git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./

数据预处理

我的的自己的数据,总体步骤分两步:
①python dataset_conversion/xxx_3d.py
主要内容:重采样,裁剪(保留前景部分)
得到dataset.yaml
②python nii2npy.py
75,76行缺少一个括号,在逗号面前加
主要内容:HU值截断
CT:[-991,500]
MR:percentile_2 = np.percentile(img, 2, axis=None)
percentile_98 = np.percentile(img, 98, axis=None)
主要内容:归一化
保存为npy文件
!!!!!!要先确定好自己的patch_size!!!!!!!!!,这里默认pad到[128,128,128),我需要192,然后发现报错。。。。。

训练

①修改参数
config/universal/hermes_resunet_3d.yaml
我用自己的数据,主要修改:

data_root
tn: 13  # the number of task priors + one null task token
mn: 3  # the number of modality priors
training_size
dataset_name_lis
dataset_classes_list #去掉背景
window_size #与training_size一致

training/dataset/dim3/dataset_config.py
根据需求修改

train_test_split#数据集划分
dataset_lab_map#任务编码
dataset_modality_map#模态编码
dataset_sample_weight#采样权重,不知道作者怎么设置的权重,作者说小数据集有高权重,我靠感觉设置的自己数据集的权重
dataset_aug_prob #数据增强参数

其中的dataset_lab_map是任务编码,不是分割目标编码哦,和Uniseg不同,这里是将一个分割目标视作为一个任务,而Uniseg是一个数据集一个任务。采样也和Uniseg不同,Uniseg是一个Batch_size同一个任务,而这里是混合的多个任务。

直接运行,没有报错

python train.py --batch_size 8 --gpu 0,1,2,3

查看tensorboard
按ctrl+shift+p,输入并选择 TensorBoard: Start TensorBoard,提供日志目录(log/universal/test),即可。

测试

我需要可视化结果,所以自己写了一个,可能有错哦

import os
import argparse
import yaml
from training.dataset.utils import get_dataset
from torch.utils import data
from model.utils import get_model
import numpy as np
import torch
from inference.utils import get_inference
from tqdm import tqdm
import SimpleITK as sitk

parser = argparse.ArgumentParser()

parser.add_argument('--model', type=str,
                    default='universal', help='model_name')
parser.add_argument('--gpu', type=str, default='4', help='GPU to use')
parser.add_argument('--model_path', type=str, default='./exp', help='log path')
parser.add_argument('--unique_name', type=str, default='hermes_dualattention', help='unique experiment name')
parser.add_argument('--result_path', type=str, default='./result', help='log path')
parser.add_argument('--prefix', type=str, default='test', help='prefix')
parser.add_argument('--dimension', type=str, default='3d', help='2d model or 3d model')
parser.add_argument('--dataset', type=str, default='universal', help='dataset name')

if __name__ == '__main__':
    # parse the arguments
    args = parser.parse_args()

    config_path = 'config/universal/hermes_resunet_3d.yaml'
    if not os.path.exists(config_path):
        raise ValueError("The specified configuration doesn't exist: %s"%config_path)
    with open(config_path, 'r') as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)
    for key, value in config.items():
        setattr(args, key, value)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    net = get_model(args)
    model_folder = "{}/{}/{}".format(args.model_path, args.model, args.unique_name)
    save_mode_path = os.path.join(model_folder, 'best.pth')
    # 加载模型检查点
    checkpoint = torch.load(save_mode_path)
    net.load_state_dict(checkpoint['model_state_dict'])
    inference = get_inference(args)

    testLoader_list = []

    for dataset_name in args.dataset_name_list:  #加载各个数据集
        print("task",dataset_name)
        save_path = os.path.join(args.result_path,dataset_name)

        if not os.path.exists(save_path):
            os.makedirs(save_path)
        testset = get_dataset(args, dataset_name_list=[dataset_name], mode='test')
        print('length of testset:', len(testset))
        testLoader = data.DataLoader(
            testset,
            batch_size=1,  # has to be 1 sample per gpu, as the input size of 3D input is different
            shuffle=False, 
            sampler=None,
            pin_memory=True,
            num_workers=0
        )
        iterator = tqdm(testLoader)
        for (images, _, tgt_idx, mod_idx, spacing,img_name) in iterator:
            images= images.cuda().float()
            tgt_idx = tgt_idx.cuda().long()
            mod_idx = mod_idx.cuda().long().unsqueeze(1)
            C = torch.nonzero(tgt_idx.squeeze(0)+1).shape[0]  #实际类别数
            pred = inference(net, images, tgt_idx, mod_idx, args)#返回[1,classes,D,H,W]
            pred = pred.to(torch.int8)
            pred = pred[0, :, :, :, :]
            pred = pred[:C, :, :, :]
            pred = torch.argmax(pred, dim=0)          
            torch.cuda.empty_cache()
            pred_np = pred.cpu().numpy().astype(np.int8)
            print(f"Prediction shape: {pred_np.shape}")
            print(f"Prediction unique values: {np.unique(pred_np)}")
            sitk_label = sitk.GetImageFromArray(pred_np)
            spacing_list = spacing.squeeze(0).tolist()
            print(f"Spacing: {spacing_list}")
            sitk_label.SetSpacing(spacing_list)

            output_filename = os.path.splitext(os.path.basename(img_name[0]))[0]
            save_label_path = os.path.join(save_path, '{}.nii.gz'.format(output_filename))

            sitk.WriteImage(sitk_label, save_label_path)
            loaded_label = sitk.ReadImage(save_label_path)
            print(f"Loaded label shape: {sitk.GetArrayFromImage(loaded_label).shape}")









需要修改training/dataset/dim3/dataset_universal.py 里的class UniversalDataset(Dataset),修改为:

        if self.mode == 'train':
            return tensor_img, tensor_lab.to(torch.int8), torch.from_numpy(tgt).to(tensor_img.device), torch.from_numpy(np.array(self.mod_list[idx])).to(tensor_img.device)
        elif self.mode == 'val':
            return tensor_img, temp_lab, torch.from_numpy(tgt), torch.from_numpy(np.array(self.mod_list[idx])), np.array(self.spacing_list[idx])
        else:
            return tensor_img, temp_lab, torch.from_numpy(tgt), torch.from_numpy(np.array(self.mod_list[idx])), np.array(self.spacing_list[idx]),self.img_list[idx]

部分代码解读

有点乱
以CNN为框架的话,会进行五次融合,分别是下采样最后两层,瓶颈层,以及上采样最开始两层。
采用交叉注意力机制进行融合。
任务以及模态的prompt使用可学习参数,其中任务prompt大小为[total_tasknum+1,c],模态为[total_modalitynum,10,c],其中10是固定的。

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值