医学图像处理模板 pytorch_lightning+monai

模板 

# -*-coding:utf-8-*-
import pytorch_lightning as pl
from monai import transforms
import numpy as np
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from monai.config import KeysCollection
from monai.utils import set_determinism

# Fix the random seed (42) through both PyTorch Lightning and MONAI so that
# repeated runs of the template start from the same random state.
pl.seed_everything(42)
set_determinism(42)


class Config:
    """Empty placeholder for experiment settings; add attributes as needed."""


class ObserveShape(transforms.MapTransform):
    """Debugging transform: print the ``.shape`` of every configured key.

    The sample is returned as a shallow ``dict`` copy; its values are not
    modified.
    """

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data):
        sample = dict(data)
        for name in self.keys:
            entry = sample[name]
            # Author's note: the input is expected to be laid out as (X, Y, Z).
            print(entry.shape)
        return sample


# Intended for segmentation targets whose regions overlap.
class ConvertLabeld(transforms.MapTransform):
    """Convert an integer label map into two overlapping binary channels.

    For every configured key the label array is replaced by a float array
    with a new leading channel axis:

    * channel 0 - voxels labelled 1 or 2 (tumor merged into pancreas),
    * channel 1 - voxels labelled 2 only (tumor).

    Args:
        keys: keys of the sample dict that hold label maps.
        allow_missing_keys: forwarded unchanged to ``MapTransform``.
    """

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
        super(ConvertLabeld, self).__init__(keys, allow_missing_keys)

    def __call__(self, data):
        """Return a shallow copy of ``data`` with the label keys converted."""
        d = dict(data)
        for key in self.keys:
            img = d[key]
            channels = [
                np.logical_or(img == 1, img == 2),  # pancreas incl. tumor
                img == 2,  # tumor channel
            ]
            # ``np.float`` (a plain alias of the builtin ``float``) was removed
            # in NumPy 1.24; ``np.float64`` is the dtype it resolved to.
            d[key] = np.stack(channels, axis=0).astype(np.float64)
        return d


class LitsDataSet(pl.LightningDataModule):
    """Skeleton data module: fill in the placeholder hooks for a real dataset."""

    def __init__(self, cfg=Config()):
        super().__init__()

    def prepare_data(self):
        """Run the one-off initialisation implemented in ``get_init``."""
        self.get_init()

    def setup(self, stage=None) -> None:
        """Split into train / validation / test sets, then build the
        preprocessing and augmentation pipelines."""
        self.split_dataset()
        self.get_preprocess()

    def train_dataloader(self):
        """Placeholder: return the training ``DataLoader``."""

    def val_dataloader(self):
        """Placeholder: return the validation ``DataLoader``."""

    def test_dataloader(self):
        """Placeholder: return the test ``DataLoader``."""

    def get_preprocess(self):
        """Placeholder: define the train / test transforms (data loading,
        augmentation, pixel / voxel normalisation, and so on)."""

    def get_init(self):
        """Placeholder: initialisation invoked from ``prepare_data``."""

    def split_dataset(self):
        """Placeholder: partition the samples; invoked from ``setup``."""


class Lung(pl.LightningModule):
    """Skeleton Lightning module: every hook is a placeholder to be filled in."""

    def __init__(self, cfg=Config()):
        """Define the network, the loss, the metrics and any label
        post-processing functions here."""
        super().__init__()

    def configure_optimizers(self):
        """Placeholder: return the optimizer (and scheduler, if any)."""

    def training_step(self, batch, batch_idx):
        """Placeholder: process one training batch."""

    def validation_step(self, batch, batch_idx):
        """Placeholder: process one validation batch."""

    def test_step(self, batch, batch_idx):
        """Placeholder: process one test batch."""

    def training_epoch_end(self, outputs):
        """Placeholder: summarise the training epoch."""

    def validation_epoch_end(self, outputs):
        """Placeholder: summarise the validation epoch."""

    def test_epoch_end(self, outputs):
        """Placeholder: summarise the test epoch."""

    def shared_epoch_end(self, outputs, loss_key):
        """Placeholder for logic common to the training / validation / test
        ``*_epoch_end`` hooks."""

    def shared_step(self, y_hat, y):
        """Placeholder for logic common to ``training_step``,
        ``validation_step`` and ``test_step``."""


# Template entry point: build the data module, the model and the callbacks,
# then hand everything to the Lightning trainer.
cfg = Config()
data = LitsDataSet()
model = Lung()

# Both callbacks are created with library defaults; configure them per project.
early_stop = EarlyStopping()
check_point = ModelCheckpoint()

trainer = pl.Trainer(
    gpus=1,
    # auto_select_gpus=True,  # author's note: cannot be combined with mixed-precision training
    precision=16,  # 16 selects half-precision training
    accumulate_grad_batches=4,
    auto_scale_batch_size=True,
    auto_lr_find=True,
    callbacks=[early_stop, check_point],
    num_sanity_val_steps=0,
    log_every_n_steps=10,
    progress_bar_refresh_rate=10,
)
trainer.fit(model, data)

 基于MSD的pancreas数据集的分割例子:

# -*-coding:utf-8-*-
import os
import random

import torch
from torch import nn, functional as F, optim
import monai
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from monai import transforms
from monai.transforms import Compose
from monai.transforms import LoadImaged, LoadImage
from monai.data import Dataset, SmartCacheDataset
from torch.utils.data import DataLoader, random_split
from glob import glob
import numpy as np
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from monai.config import KeysCollection
from torch.utils.data import random_split
from SwinUnet_3D import swinUnet_t_3D
from monai.losses import DiceLoss, DiceFocalLoss, DiceCELoss, FocalLoss
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.inferers import sliding_window_inference
from monai.utils import set_determinism
from monai.data import decollate_batch, list_data_collate
from monai.networks.utils import one_hot
from einops import rearrange
from torchmetrics.functional import dice_score
from torchmetrics import IoU, Accuracy
from monai.transforms import (
    Activations,
    Activationsd,
    AsDiscrete,
    AsDiscreted,
    Compose,
    Invertd,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureChannelFirstd,
    EnsureTyped,
    EnsureType,
    ConvertToMultiChannelBasedOnBratsClassesd,
    SpatialPadd,
    ScaleIntensityRangePercentilesd,
    ScaleIntensityRanged,
    CropForegroundd,
    RandCropByPosNegLabeld
)

# Fix the random seed (42) through both PyTorch Lightning and MONAI so that
# the random train/validation split and the random augmentations repeat
# across runs.
pl.seed_everything(42)
set_determinism(42)


class Config:
    """Paths and hyper-parameters for the MSD pancreas segmentation example."""

    # Dataset root; the imagesTr / labelsTr / imagesTs folders are expected below it.
    data_path = r'D:\Caiyimin\Dataset\MSD\Pancreas'

    # Patch size fed to the network.
    FinalShape = [160, 160, 160]
    # Window size for SwinUnet3D; FinalShape[i] must be divisible by window_size[i].
    window_size = [5, 5, 5]
    in_channels = 1

    # Author's notes on the dataset at 1.0 voxel spacing: the median volume
    # size is (411, 411, 240); the z extent ranges from 127 to 499.
    ResamplePixDim = (2.0, 2.0, 1.0)

    # Intensity clipping range handed to ScaleIntensityRanged.
    # NOTE(review): the upper bound uses centre 50 and the lower bound centre 35
    # with the same width of 350 - confirm that the asymmetry is intended.
    HuMax = 225.0  # 50 + 350 / 2
    HuMin = -140.0  # 35 - 350 / 2
    low_percent = 0.5
    upper_percent = 99.5

    # Fractions of the labelled data used for training / validation / test.
    train_ratio = 0.8
    val_ratio = 0.2
    test_ratio = 0.0
    BatchSize = 1
    NumWorkers = 0

    # Two output channels: pancreas and cancer.
    n_classes = 2

    # Learning rate.
    lr = 3e-5

    # Backbone selector; 'Unet3D' and 'UnetR' are the alternatives listed by the author.
    back_bone_name = 'SwinUnet'

    # Used for sliding-window inference (roi_size is the same list object as FinalShape).
    roi_size = FinalShape
    slid_window_overlap = 0.5


class ObserveShape(transforms.MapTransform):
    """Debugging transform: print the ``.shape`` of every configured key.

    The sample is returned as a shallow ``dict`` copy; its values are not
    modified.
    """

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data):
        sample = dict(data)
        for name in self.keys:
            entry = sample[name]
            # Author's note: the input is expected to be laid out as (X, Y, Z).
            print(entry.shape)
        return sample


class ConvertLabeld(transforms.MapTransform):
    """Convert an integer label map into two overlapping binary channels.

    For every configured key the label array is replaced by a float array
    with a new leading channel axis:

    * channel 0 - voxels labelled 1 or 2 (tumor merged into pancreas),
    * channel 1 - voxels labelled 2 only (tumor).

    Args:
        keys: keys of the sample dict that hold label maps.
        allow_missing_keys: forwarded unchanged to ``MapTransform``.
    """

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
        super(ConvertLabeld, self).__init__(keys, allow_missing_keys)

    def __call__(self, data):
        """Return a shallow copy of ``data`` with the label keys converted."""
        d = dict(data)
        for key in self.keys:
            img = d[key]
            channels = [
                np.logical_or(img == 1, img == 2),  # pancreas incl. tumor
                img == 2,  # tumor channel
            ]
            # ``np.float`` (a plain alias of the builtin ``float``) was removed
            # in NumPy 1.24; ``np.float64`` is the dtype it resolved to.
            d[key] = np.stack(channels, axis=0).astype(np.float64)
        return d


class LitsDataSet(pl.LightningDataModule):
    """Data module for the MSD pancreas example.

    ``prepare_data`` lists the ``*.nii.gz`` files under ``imagesTr`` /
    ``labelsTr`` / ``imagesTs`` and builds the transform pipelines;
    ``setup`` randomly splits the labelled samples according to the ratios in
    ``cfg`` and wraps each part in a MONAI ``Dataset``.

    Both hooks are idempotent: once the datasets exist, later calls are
    no-ops, so the training subset is never split a second time.

    NOTE(review): the sample lists are filled in ``prepare_data``, which
    Lightning documents as running on a single process only - confirm before
    using this module with multi-process training.
    """

    def __init__(self, cfg=Config()):
        super(LitsDataSet, self).__init__()
        self.cfg = cfg
        self.data_path = cfg.data_path
        self.train_path = os.path.join(cfg.data_path, 'imagesTr')
        self.label_tr_path = os.path.join(cfg.data_path, 'labelsTr')
        self.test_path = os.path.join(cfg.data_path, 'imagesTs')

        # Lists of {'image': path, 'label': path} records (test records have
        # no 'label'); replaced by random_split subsets in split_dataset().
        self.train_dict = []
        self.val_dict = []
        self.test_dict = []

        self.train_set = None
        self.val_set = None
        self.test_set = None

        self.train_process = None
        self.val_process = None

    def prepare_data(self):
        """Collect the file records and build the transform pipelines.

        Raises:
            ValueError: if the numbers of training images and labels differ,
                because pairing them positionally would then be wrong.
        """
        if self.train_set is not None:
            # Already set up; rebuilding the lists would desynchronise them
            # from the datasets created in setup().
            return

        train_x, train_y, test_x = self.get_init()
        if len(train_x) != len(train_y):
            raise ValueError(
                f'Found {len(train_x)} training images but {len(train_y)} '
                f'labels under {self.data_path}')

        # Build fresh lists instead of appending, so that a repeated call
        # cannot duplicate records.
        self.train_dict = [{'image': x, 'label': y}
                           for x, y in zip(train_x, train_y)]
        self.test_dict = [{'image': x} for x in test_x]
        self.get_preprocess()

    def setup(self, stage=None) -> None:
        """Split the labelled samples and create the three datasets (once)."""
        if self.train_set is not None:
            # A second call would re-split the already-split training subset
            # and shrink it; keep the first split.
            return

        self.split_dataset()
        self.train_set = Dataset(self.train_dict, transform=self.train_process)
        self.val_set = Dataset(self.val_dict, transform=self.val_process)
        self.test_set = Dataset(self.test_dict, transform=self.val_process)

    def train_dataloader(self):
        """Return the training loader.

        ``list_data_collate`` is used here, presumably because the random
        crop transform yields a list of patches per sample - confirm against
        the installed MONAI version.
        """
        cfg = self.cfg
        return DataLoader(self.train_set, batch_size=cfg.BatchSize,
                          num_workers=cfg.NumWorkers,
                          collate_fn=list_data_collate)

    def val_dataloader(self):
        """Return the validation loader (whole volumes, default collate)."""
        cfg = self.cfg
        return DataLoader(self.val_set, batch_size=cfg.BatchSize,
                          num_workers=cfg.NumWorkers)

    def test_dataloader(self):
        """Return the test loader (whole volumes, default collate)."""
        cfg = self.cfg
        return DataLoader(self.test_set, batch_size=cfg.BatchSize,
                          num_workers=cfg.NumWorkers)

    def _base_transforms(self):
        """Return new instances of the transforms shared by both pipelines:
        load, add the channel axis, reorient, resample and rescale."""
        cfg = self.cfg
        return [
            LoadImaged(keys=['image', 'label']),
            EnsureChannelFirstd(keys=['image']),
            # The label gets its channel axis from ConvertLabeld.
            ConvertLabeld(keys='label'),

            Orientationd(keys=['image', 'label'], axcodes='RAS'),
            Spacingd(keys=['image', 'label'], pixdim=cfg.ResamplePixDim,
                     mode=('bilinear', 'nearest')),

            ScaleIntensityRanged(keys='image', a_min=cfg.HuMin, a_max=cfg.HuMax,
                                 b_min=0.0, b_max=1.0, clip=True),
        ]

    def get_preprocess(self):
        """Build ``train_process`` (with padding, random crop and
        augmentation) and ``val_process`` (deterministic, whole volume)."""
        cfg = self.cfg
        both = ['image', 'label']

        self.train_process = Compose([
            *self._base_transforms(),
            # Pad first so every volume is at least as large as the crop.
            SpatialPadd(keys=both, spatial_size=cfg.FinalShape),
            RandCropByPosNegLabeld(keys=both, label_key='label',
                                   spatial_size=cfg.FinalShape,
                                   pos=1, neg=1, num_samples=1,
                                   image_key='image'),

            # Independent random flips along each of the three spatial axes.
            *[RandFlipd(keys=both, prob=0.5, spatial_axis=axis)
              for axis in range(3)],

            RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
            EnsureTyped(keys=both),
        ])

        self.val_process = Compose([
            *self._base_transforms(),
            EnsureTyped(keys=both),
        ])

    def get_init(self):
        """Return sorted lists of training image, training label and test
        image paths (``*.nii.gz``)."""
        train_x = sorted(glob(os.path.join(self.train_path, '*.nii.gz')))
        train_y = sorted(glob(os.path.join(self.label_tr_path, '*.nii.gz')))
        test_x = sorted(glob(os.path.join(self.test_path, '*.nii.gz')))

        return train_x, train_y, test_x

    def split_dataset(self):
        """Randomly split ``train_dict`` into train / val / test subsets.

        The subset sizes follow ``cfg.train_ratio`` and ``cfg.test_ratio``;
        validation receives all remaining samples, so rounding losses are
        absorbed there. ``test_dict`` is overwritten by the third subset, so
        the ``imagesTs`` records gathered in ``prepare_data`` are not used
        afterwards.
        """
        cfg = self.cfg
        num = len(self.train_dict)
        train_num = int(num * cfg.train_ratio)
        test_num = int(num * cfg.test_ratio)
        # Same result as int(num * val_ratio) plus the rounding remainder.
        val_num = num - train_num - test_num

        self.train_dict, self.val_dict, self.test_dict \
            = random_split(self.train_dict, [train_num, val_num, test_num])


class Lung(pl.LightningModule):
    """Lightning module for two-channel (pancreas, tumor) 3D segmentation.

    The backbone is selected by ``cfg.back_bone_name``: ``'SwinUnet'``,
    ``'UnetR'``, or a MONAI ``UNet`` for any other value. Training runs the
    network on the batch directly; validation and test use sliding-window
    inference. Loss and per-channel Dice are logged per step as
    ``<stage>_*`` and per epoch as ``<stage>_mean_*``, where ``<stage>`` is
    ``train``, ``valid`` or ``test``.

    NOTE(review): a single ``DiceMetric`` instance accumulates results from
    every stage and is only reset in ``shared_epoch_end``. If validation
    steps run before ``training_epoch_end``, training and validation scores
    end up in the same buffer - confirm the hook order of the installed
    Lightning version.
    """

    def __init__(self, cfg=Config()):
        super(Lung, self).__init__()
        self.cfg = cfg
        if cfg.back_bone_name == 'SwinUnet':
            self.net = swinUnet_t_3D(window_size=cfg.window_size,
                                     num_classes=cfg.n_classes,
                                     in_channel=cfg.in_channels)
        else:
            from monai.networks.nets import UNETR, UNet
            if cfg.back_bone_name == 'UnetR':
                self.net = UNETR(in_channels=cfg.in_channels,
                                 out_channels=cfg.n_classes,
                                 img_size=cfg.FinalShape)
            else:
                # Take the channel count from the config like the other
                # backbones do, instead of a hard-coded 1.
                self.net = UNet(spatial_dims=3, in_channels=cfg.in_channels,
                                out_channels=cfg.n_classes,
                                channels=(32, 64, 128, 256, 512),
                                strides=(2, 2, 2, 2))

        # sigmoid=True: the network emits raw logits and the channels are
        # scored independently (they overlap: tumor lies inside pancreas).
        self.loss_func = DiceLoss(smooth_nr=0, smooth_dr=1e-5,
                                  squared_pred=False,
                                  sigmoid=True)

        self.metrics = DiceMetric(include_background=True,
                                  reduction='mean_batch')
        # NOTE(review): ``threshold_values`` is the spelling of older MONAI
        # releases; newer ones may expect ``threshold=0.5`` - confirm against
        # the installed version.
        self.post_pred = Compose([
            EnsureType(), Activations(sigmoid=True),
            AsDiscrete(threshold_values=True)
        ])

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters with ``cfg.lr``."""
        cfg = self.cfg
        opt = optim.AdamW(params=self.parameters(), lr=cfg.lr, eps=1e-7,
                          weight_decay=1e-5)

        # lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        #     opt, T_0=5, T_mult=1, )
        # return {'optimizer': opt, 'lr_scheduler': lr_scheduler, 'monitor': 'valid_loss'}

        return opt

    def _log_scores(self, prefix, loss, dice):
        """Log loss, pancreas Dice and tumor Dice under ``prefix``-based keys."""
        p_dice, t_dice = dice[0], dice[1]
        self.log(f'{prefix}_loss', loss, prog_bar=True)
        self.log(f'{prefix}_pancreas_dice', p_dice, prog_bar=True)
        self.log(f'{prefix}_tumor_dice', t_dice, prog_bar=True)

    def _sliding_window_predict(self, x, sw_batch_size):
        """Run the network over ``x`` in overlapping ``FinalShape`` windows."""
        cfg = self.cfg
        return sliding_window_inference(x, roi_size=cfg.FinalShape,
                                        sw_batch_size=sw_batch_size,
                                        predictor=self.net,
                                        overlap=cfg.slid_window_overlap)

    def training_step(self, batch, batch_idx):
        """Forward the cropped patch directly and log ``train_*`` values."""
        x = batch['image']
        y = batch['label']
        y_hat = self.net(x)

        loss, dice = self.shared_step(y_hat=y_hat, y=y)
        self._log_scores('train', loss, dice)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Sliding-window inference on a whole volume; logs ``valid_*``."""
        x = batch['image']
        y = batch['label']
        y_hat = self._sliding_window_predict(x, sw_batch_size=self.cfg.BatchSize)

        loss, dice = self.shared_step(y_hat=y_hat, y=y)
        self._log_scores('valid', loss, dice)
        return {'loss': loss}

    def test_step(self, batch, batch_idx):
        """Sliding-window inference on a whole volume; logs ``test_*``."""
        x = batch['image']
        y = batch['label']
        y_hat = self._sliding_window_predict(x, sw_batch_size=1)

        loss, dice = self.shared_step(y_hat=y_hat, y=y)
        self._log_scores('test', loss, dice)
        return {'loss': loss}

    def training_epoch_end(self, outputs):
        """Log the epoch means as ``train_mean_*``."""
        losses, dice = self.shared_epoch_end(outputs, 'loss')
        self._log_scores('train_mean', losses, dice)

    def validation_epoch_end(self, outputs):
        """Log the epoch means as ``valid_mean_*``."""
        losses, dice = self.shared_epoch_end(outputs, 'loss')
        self._log_scores('valid_mean', losses, dice)

    def test_epoch_end(self, outputs):
        """Log the epoch means as ``test_mean_*``.

        These were previously written under the ``valid_mean_*`` keys, which
        mislabelled the test results as validation results.
        """
        losses, dice = self.shared_epoch_end(outputs, 'loss')
        self._log_scores('test_mean', losses, dice)

    def shared_epoch_end(self, outputs, loss_key):
        """Average the step losses and aggregate, then reset, the Dice metric.

        Args:
            outputs: the dicts returned by the step methods during the epoch.
            loss_key: key under which each dict stores its loss tensor.

        Returns:
            ``(mean_loss, dice)``: the NumPy mean of the step losses and the
            aggregated Dice values as a NumPy array.
        """
        losses = np.mean(np.array([output[loss_key].item() for output in outputs]))

        dice = self.metrics.aggregate()
        self.metrics.reset()

        dice = dice.detach().cpu().numpy()
        return losses, dice

    def shared_step(self, y_hat, y):
        """Compute the loss and the Dice scores for one batch.

        Args:
            y_hat: raw network output (logits) for the batch.
            y: ground-truth label tensor of the batch.

        Returns:
            ``(loss, dice)`` with NaNs replaced by ``torch.nan_to_num`` and
            ``dice`` averaged over the batch dimension.
        """
        loss = self.loss_func(y_hat, y)

        # The metric works on per-sample lists of binarised predictions.
        y_hat = [self.post_pred(it) for it in decollate_batch(y_hat)]
        y = decollate_batch(y)

        dice = self.metrics(y_hat, y)

        dice = torch.nan_to_num(dice)
        loss = torch.nan_to_num(loss)

        dice = torch.mean(dice, dim=0)
        return loss, dice


# Entry point of the example: build the data module, the model, the callbacks
# and the trainer, then start training.
data = LitsDataSet()
model = Lung()

# Stop once the epoch-level validation loss has not improved for 10 checks.
early_stop = EarlyStopping(
    monitor='valid_mean_loss',
    patience=10,
)

cfg = Config()
# Keep the two checkpoints with the best epoch-level validation loss. The
# file name may only reference metrics that the model actually logs; the
# former `valid_mean_dice` placeholder was never logged, so it is replaced by
# the two per-channel Dice keys written in `validation_epoch_end`.
check_point = ModelCheckpoint(dirpath=f'./trained_models/{cfg.back_bone_name}',
                              save_last=False,
                              save_top_k=2, monitor='valid_mean_loss', verbose=True,
                              filename='{epoch}-{valid_loss:.2f}'
                                       '-{valid_mean_pancreas_dice:.2f}'
                                       '-{valid_mean_tumor_dice:.2f}')
trainer = pl.Trainer(
    progress_bar_refresh_rate=10,
    max_epochs=400,
    min_epochs=30,
    gpus=1,
    # auto_select_gpus=True,  # author's note: cannot be combined with mixed-precision training

    # NOTE(review): auto_scale_batch_size / auto_lr_find presumably only take
    # effect when trainer.tune() is called, which this script does not do -
    # confirm against the installed Lightning version.
    auto_scale_batch_size=True,
    logger=TensorBoardLogger(save_dir='./logs', name=f'{cfg.back_bone_name}'),
    callbacks=[early_stop, check_point],
    precision=16,  # half-precision training
    accumulate_grad_batches=4,
    num_sanity_val_steps=0,
    log_every_n_steps=10,
    auto_lr_find=True
)
trainer.fit(model, data)

适用于医学图像分割的 SwinUnet3D 源码暂时还不能公开。运行上面的例子之前,请先删除(或注释掉)`from SwinUnet_3D import swinUnet_t_3D` 这一行导入,并将 `Config.back_bone_name` 改为 `'UnetR'` 或 `'Unet3D'` 等其他网络。

  • 4
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
根据引用[3]的介绍,如果出现 "No module named 'pytorch_lightning'" 的错误,可能是因为没有正确安装 pytorch_lightning 模块。请确保已经正确安装了 pytorch_lightning 模块,可以使用以下命令进行安装:`pip install pytorch_lightning`。如果已经安装了 pytorch_lightning 模块,但仍然出现该错误,请确保已经正确导入了该模块,可以使用以下语句进行导入:`import pytorch_lightning`。如果问题仍然存在,请检查您的环境配置和安装是否正确,并确保您的代码中没有拼写错误或其他语法错误。

引用:
1. [No module named 'pytorch_lightning.utilities.distributed'](https://blog.csdn.net/SPESEG/article/details/131530183)
2. [(已解决)ModuleNotFoundError: No module named 'pytorch_lightning.metrics'](https://blog.csdn.net/qq_43391414/article/details/124412694)
3. [如何解决 ModuleNotFoundError: No module named 'pytorch_lightning.metrics'](https://blog.csdn.net/lyf6667/article/details/125673107)

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值