[2504.06742] nnLandmark: A Self-Configuring Method for 3D Medical Landmark Detectionhttps://arxiv.org/abs/2504.06742故事的起因就是看到这篇文章之后,+1觉得地标检测的效果不错,希望我能改出来,方便以后分割肝八段的时候使用。虽然本人觉得原作者应该很快就能把代码发布出来,但谁知道会是什么时候呢,并且+1一直在催,所以就花了一些时间,将nnunet的框架捋了一遍,将其改成了nnlandmark的形式。
众所周知,nnunet是专用于图像分割的一个框架,且在医学图像分割方面表现卓越,其本质是对输入的图像进行分类,分成前景和背景;而nnlandmark是用于地标检测的,它的本质是对输入的图像进行热图回归,输出连续值,至于连续值是0~1,还是-1~1,还是0~100,就看自己的需求更改。
接下来进入正文,到底应该怎么改?根据文章[2504.06742] nnLandmark: A Self-Configuring Method for 3D Medical Landmark Detection的描述,大致分为以下几个步骤:(1)为了确保与nnU-Net的实验规划和预处理兼容,地标标签最初采用多类分割格式;(2)在训练过程中,这些分割标签被转换为热图,每个地标分配一个独立的通道,每个地标表示为高斯斑点——标准差σ = 4,并在0到1之间归一化;(3)在网络的最后一层,加入一个sigmoid激活函数,将预测的体素值限制在0到1之间,从而稳定热图回归训练。网络输出通道的数量设置为地标数量;(4)使用MSE损失函数来回归连续的热图值;(5)选择Adam优化器来处理多通道热图回归中的强烈类别不平衡问题(每个通道包含一个完整的体积,但前景仅是一个相对较小的高斯斑点);(6)在后处理过程中,通过识别通道的最大值从预测的热图中提取地标坐标。
1. 对应步骤(1)。这一步是对数据的修改,原始的地标标签大部分是以坐标的形式呈现的,在这一步中,需要将“坐标”转为“点”,并对不同的点赋予不同的标签值(1、2、3、...、n,n表示地标的数量,不从0开始的原因是,0表示背景)。
2. 对应步骤(2)-(5),修改训练过程。
2.1. 首先是对nnUNetTrainer文件的修改,路径在nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py,但这边建议最好不要在原文件进行修改,而是新建一个文件nnUNetTrainerlandmark.py。路径可以放在nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainerlandmark.py,下面是全部代码。
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager
from torch import nn
import inspect
import multiprocessing
import os
import shutil
import sys
import warnings
from copy import deepcopy
from datetime import datetime
from time import time, sleep
from typing import Tuple, Union, List
import numpy as np
import torch
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p
from batchgeneratorsv2.helpers.scalar_type import RandomScalar
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform
from batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast
from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform
from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform
from batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform
from batchgeneratorsv2.transforms.nnunet.remove_connected_components import \
RemoveRandomConnectedComponentFromOneHotEncodingTransform
from batchgeneratorsv2.transforms.nnunet.seg_to_onehot import MoveSegAsOneHotToDataTransform
from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform
from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform
from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform
from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform
from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms
from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform
from batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform
from batchgeneratorsv2.transforms.utils.pseudo2d import Convert3DTo2DTransform, Convert2DTo3DTransform
from batchgeneratorsv2.transforms.utils.random import RandomTransform
from batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform
from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform,ConvertSegToGaussianHeatmapTransform
from torch import autocast, nn
from torch import distributed as dist
from torch._dynamo import OptimizedModule
from torch.cuda import device_count
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes
from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder
from nnunetv2.inference.export_prediction import export_prediction_from_logits, resample_and_save
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.inference.sliding_window_prediction import compute_gaussian
from nnunetv2.paths import nnUNet_preprocessed, nnUNet_results
from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size
from nnunetv2.training.dataloading.data_loader_2d import nnUNetDataLoader2D
from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D,nnUNetDataLoader_3D
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
from nnunetv2.training.dataloading.utils import get_case_identifiers, unpack_dataset
from nnunetv2.training.logging.nnunet_logger import nnUNetLogger
from nnunetv2.training.loss.compound_losses import DC_and_CE_loss, DC_and_BCE_loss, MSEHeatmapRegressionLoss
from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss
from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.utilities.collate_outputs import collate_outputs
from nnunetv2.utilities.crossval_split import generate_crossval_split
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy
from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
from nnunetv2.utilities.helpers import empty_cache, dummy_context
from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
class nnUNetTrainerlandmark(nnUNetTrainer):
def configure_optimizers(self):
    """Create the optimizer and LR scheduler for heatmap-regression training.

    Replaces nnU-Net's default SGD with Adam (nnLandmark step 5: Adam
    copes better with the strong foreground/background imbalance of
    multi-channel heatmap regression), combined with the usual
    polynomial learning-rate decay over ``self.num_epochs`` epochs.

    Returns:
        A ``(optimizer, lr_scheduler)`` tuple, as expected by the base trainer.
    """
    adam = torch.optim.Adam(self.network.parameters(),
                            lr=self.initial_lr,
                            weight_decay=self.weight_decay)
    scheduler = PolyLRScheduler(adam, self.initial_lr, self.num_epochs)
    print('use Adam!')
    return adam, scheduler
def _build_loss(self):
    """Build the training loss for landmark heatmap regression.

    Region-based training keeps nnU-Net's Dice+BCE compound loss. In the
    standard (label-based) case, the default Dice+CE segmentation loss is
    replaced by an MSE loss that regresses the continuous Gaussian-heatmap
    targets (nnLandmark step 4). With deep supervision enabled, the loss is
    wrapped so every output resolution contributes with exponentially
    decaying weight.

    Returns:
        The (possibly deep-supervision-wrapped) loss module.
    """
    if self.label_manager.has_regions:
        loss = DC_and_BCE_loss({},
                               {'batch_dice': self.configuration_manager.batch_dice,
                                'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp},
                               use_ignore_label=self.label_manager.ignore_label is not None,
                               dice_class=MemoryEfficientSoftDiceLoss)
    else:
        # nnLandmark: regress continuous heatmap values with MSE instead of
        # the default Dice + cross-entropy segmentation loss.
        loss = MSEHeatmapRegressionLoss(ignore_label=self.label_manager.ignore_label, reduction='mean')

    # BUGFIX: only the Dice-based compound losses expose a `.dc` submodule.
    # Compiling it unconditionally raised AttributeError on the MSE loss path.
    if self._do_i_compile() and hasattr(loss, 'dc'):
        loss.dc = torch.compile(loss.dc)

    if self.enable_deep_supervision:
        deep_supervision_scales = self._get_deep_supervision_scales()
        # Each output gets a weight that decreases exponentially (division
        # by 2) as the resolution decreases, so the full-resolution output
        # dominates the loss.
        weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
        if self.is_ddp and not self._do_i_compile():
            # DDP crashes with "unused parameters" when a weight is exactly 0
            # (oddly not under torch.compile), so use a tiny epsilon instead.
            weights[-1] = 1e-6
        else:
            weights[-1] = 0
        # Normalize so the weights sum to 1.
        weights = weights / weights.sum()
        loss = DeepSupervisionWrapper(loss, weights)
    return loss
def get_dataloaders(self):
patch_size = self.c