文章:Training Like a Medical Resident: Universal Medical Image Segmentation via Context Prior Learning
2024 CVPR
1.环境安装
先新建一个环境吧
conda create -n hermes python=3.9
1.1 pip install -r requirements.txt
①
解决方法:去掉 fonttools
安装apex
根据官方指引安装
APEX介绍以及安装:英伟达(NVIDIA)训练深度学习模型神器APEX使用指南
我用的官方网站的安装
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
数据预处理
我的的自己的数据,总体步骤分两步:
①python dataset_conversion/xxx_3d.py
主要内容:重采样,裁剪(保留前景部分)
得到dataset.yaml
②python nii2npy.py
75,76行缺少一个括号,在逗号面前加
主要内容:HU值截断
CT:[-991,500]
MR:percentile_2 = np.percentile(img, 2, axis=None)
percentile_98 = np.percentile(img, 98, axis=None)
主要内容:归一化
保存为npy文件
!!!!!!要先确定好自己的patch_size!!!!!!!!!,这里默认pad到[128,128,128),我需要192,然后发现报错。。。。。
训练
①修改参数
config/universal/hermes_resunet_3d.yaml
我用自己的数据,主要修改:
data_root
tn: 13 # the number of task priors + one null task token
mn: 3 # the number of modality priors
training_size
dataset_name_lis
dataset_classes_list #去掉背景
window_size #与training_size一致
training/dataset/dim3/dataset_config.py
根据需求修改
train_test_split#数据集划分
dataset_lab_map#任务编码
dataset_modality_map#模态编码
dataset_sample_weight#采样权重,不知道作者怎么设置的权重,作者说小数据集有高权重,我靠感觉设置的自己数据集的权重
dataset_aug_prob #数据增强参数
其中的dataset_lab_map是任务编码,不是分割目标编码哦,和Uniseg不同,这里是将一个分割目标视作为一个任务,而Uniseg是一个数据集一个任务。采样也和Uniseg不同,Uniseg是一个Batch_size同一个任务,而这里是混合的多个任务。
直接运行,没有报错
python train.py --batch_size 8 --gpu 0,1,2,3
查看tensorboard
按ctrl+shift+p,输入并选择 TensorBoard: Start TensorBoard,提供日志目录(log/universal/test),即可。
测试
我需要可视化结果,所以自己写了一个,可能有错哦
import os
import argparse
import yaml
from training.dataset.utils import get_dataset
from torch.utils import data
from model.utils import get_model
import numpy as np
import torch
from inference.utils import get_inference
from tqdm import tqdm
import SimpleITK as sitk
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str,
default='universal', help='model_name')
parser.add_argument('--gpu', type=str, default='4', help='GPU to use')
parser.add_argument('--model_path', type=str, default='./exp', help='log path')
parser.add_argument('--unique_name', type=str, default='hermes_dualattention', help='unique experiment name')
parser.add_argument('--result_path', type=str, default='./result', help='log path')
parser.add_argument('--prefix', type=str, default='test', help='prefix')
parser.add_argument('--dimension', type=str, default='3d', help='2d model or 3d model')
parser.add_argument('--dataset', type=str, default='universal', help='dataset name')
if __name__ == '__main__':
# parse the arguments
args = parser.parse_args()
config_path = 'config/universal/hermes_resunet_3d.yaml'
if not os.path.exists(config_path):
raise ValueError("The specified configuration doesn't exist: %s"%config_path)
with open(config_path, 'r') as f:
config = yaml.load(f, Loader=yaml.SafeLoader)
for key, value in config.items():
setattr(args, key, value)
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
net = get_model(args)
model_folder = "{}/{}/{}".format(args.model_path, args.model, args.unique_name)
save_mode_path = os.path.join(model_folder, 'best.pth')
# 加载模型检查点
checkpoint = torch.load(save_mode_path)
net.load_state_dict(checkpoint['model_state_dict'])
inference = get_inference(args)
testLoader_list = []
for dataset_name in args.dataset_name_list: #加载各个数据集
print("task",dataset_name)
save_path = os.path.join(args.result_path,dataset_name)
if not os.path.exists(save_path):
os.makedirs(save_path)
testset = get_dataset(args, dataset_name_list=[dataset_name], mode='test')
print('length of testset:', len(testset))
testLoader = data.DataLoader(
testset,
batch_size=1, # has to be 1 sample per gpu, as the input size of 3D input is different
shuffle=False,
sampler=None,
pin_memory=True,
num_workers=0
)
iterator = tqdm(testLoader)
for (images, _, tgt_idx, mod_idx, spacing,img_name) in iterator:
images= images.cuda().float()
tgt_idx = tgt_idx.cuda().long()
mod_idx = mod_idx.cuda().long().unsqueeze(1)
C = torch.nonzero(tgt_idx.squeeze(0)+1).shape[0] #实际类别数
pred = inference(net, images, tgt_idx, mod_idx, args)#返回[1,classes,D,H,W]
pred = pred.to(torch.int8)
pred = pred[0, :, :, :, :]
pred = pred[:C, :, :, :]
pred = torch.argmax(pred, dim=0)
torch.cuda.empty_cache()
pred_np = pred.cpu().numpy().astype(np.int8)
print(f"Prediction shape: {pred_np.shape}")
print(f"Prediction unique values: {np.unique(pred_np)}")
sitk_label = sitk.GetImageFromArray(pred_np)
spacing_list = spacing.squeeze(0).tolist()
print(f"Spacing: {spacing_list}")
sitk_label.SetSpacing(spacing_list)
output_filename = os.path.splitext(os.path.basename(img_name[0]))[0]
save_label_path = os.path.join(save_path, '{}.nii.gz'.format(output_filename))
sitk.WriteImage(sitk_label, save_label_path)
loaded_label = sitk.ReadImage(save_label_path)
print(f"Loaded label shape: {sitk.GetArrayFromImage(loaded_label).shape}")
需要修改training/dataset/dim3/dataset_universal.py 里的class UniversalDataset(Dataset),修改为:
if self.mode == 'train':
return tensor_img, tensor_lab.to(torch.int8), torch.from_numpy(tgt).to(tensor_img.device), torch.from_numpy(np.array(self.mod_list[idx])).to(tensor_img.device)
elif self.mode == 'val':
return tensor_img, temp_lab, torch.from_numpy(tgt), torch.from_numpy(np.array(self.mod_list[idx])), np.array(self.spacing_list[idx])
else:
return tensor_img, temp_lab, torch.from_numpy(tgt), torch.from_numpy(np.array(self.mod_list[idx])), np.array(self.spacing_list[idx]),self.img_list[idx]
部分代码解读
有点乱
以CNN为框架的话,会进行五次融合,分别是下采样最后两层,瓶颈层,以及上采样最开始两层。
采用交叉注意力机制进行融合。
任务以及模态的prompt使用可学习参数,其中任务prompt大小为[total_tasknum+1,c],模态为[total_modalitynum,10,c],其中10是固定的。