医学图像处理算法学习——3DUnetCNN(1)

医学图像处理算法学习——3DUnetCNN(1)

第一章 3DUnetCNN(1)——原demo部署运行于源码分析



前言

作为初入医学图像处理领域的新手,本博客旨在记录自己对于3DUnetCNN学习过程中遇到的问题,以及自己对代码的理解与思考,包括后续如何将自己的数据集转换为该程序可用的数据集,供自己后续回看,也供大家一起学习参考。


一、源码与数据集

代码源码下载自GitHub:link
本代码数据集采用的是BraTS2020图像库,下载地址:link
需分别下载训练集与验证集,下载后将其按下述格式放置于examples文件夹中
在这里插入图片描述

二、运行训练demo

1.搭建环境

首先创建其自己的虚拟环境
安装安装所需库

代码如下(示例):

pip install -r 3DUnetCNN/requirements.txt

==注意:==如未安装torch,需先安装torch(GPU版本)
安装好后在cmd运行下述代码确认是否正确安装:

python
import torch
print(torch.cuda.is_available())
#返回True说明安装成功

2.运行代码

代码如下(示例):

cd examples/brats2020 #打开至数据集路径
#注意源代码一定要在该路径下运行,否则会找不到数据集报错
python /path/to/unet3d/scripts/train.py --config_filename brats2020_config.json
#/path/to/unet3d/scripts/train.py替换为scripts/train.py所在的绝对路径

注意:
1.该demo运行时遇到一些问题,因此对其中部分文件进行了微调,在后续会指出。
2.注意源代码一定要在examples/brats2020路径下运行,否则会找不到数据集报错。
3.BraTS2020数据集较大,若不修改直接运行demo,需要较大的GPU内存,否则会报错torch.cuda.OutOfMemoryError: CUDA out of memory.


三、代码分析

在运行成功后,本部分旨在对3DUnetCNN源码进行分析,了解其结构与每部分代码功能,方便后续修改与确定自己的数据集,在本段中为主要代码添加了注释

1.代码结构

在这里插入图片描述
整体代码结构如上图所示,目前进了解了部分代码,其中examples/brats2020中存放的是数据集,brats2020_config.json为训练数据集的信息。
在这里插入图片描述在这里插入图片描述unet3d中为整个算法的主体部分,其中scripts/train.py与scripts/perdict.py为训练和预测的主函数,其他函数为其中各个单元。

2.train.py

注意: 由于train.py中对其他各个文件进行了引用与跳转,因此在下面讲解中将按照逻辑跳转不同文件进行分析,同时指出在运行demo时进行修改的部分。

首先是引用库,由于虽然unet3d文件夹中包含了__init__.py文件,但是仍报错找不到unet3d库,因此在程序中手动添加路径代码

sys.path.append('C:\\Users\\15496\\Desktop\\3DUnetCNN-master')
#路径为算法所在绝对路径

在这里插入图片描述

接下来是最下端的主函数部分,其重点内容为parse_args()与run(),parse_args()是参数定义函数。run()是训练算法,在main()中,首先定义参数,然后执行运行函数。后续将对这两个函数进行详细讲解。
在这里插入图片描述首先是parse_args()函数:

def parse_args():
    # 创建一个命令行参数解析器对象
    parser = argparse.ArgumentParser()
    # 添加一个命令行参数,指定配置文件的路径,参数名为--config_filename,必须提供该参数
    parser.add_argument("--config_filename", required=True,
                        help="JSON configuration file specifying the parameters for model training.")
    # 添加一个命令行参数,指定输出目录的路径,参数名为--output_dir,是可选参数
    parser.add_argument("--output_dir", required=False,
                        help="Output directory where all the outputs will be saved. "
                             "Defaults to the directory of the configuration file.")
    # 添加一个命令行参数,指定是否只生成交叉验证配置文件的标志,参数名为--setup_crossval_only
    parser.add_argument("--setup_crossval_only", action="store_true", default=False,
                        help="Only write the cross-validation configuration files. "
                             "If selected, training will not be run. Instead the filenames will be split into "
                             "folds and modified configuration files will be written to the working directory. "
                             "This is useful if you want to submit training folds to an HPC scheduler system.")
    # 添加一个命令行参数,指定预训练模型文件的路径,参数名为--pretrained_model_filename
    parser.add_argument("--pretrained_model_filename",
                        help="If this filename exists prior to training, the model will be loaded from the filename. "
                             "Default is '{output_dir}/{config_basename}/model.pth'. "
                             "The default behavior is to use flexible loading of the model where not all the "
                             "model layers/weights have to match. "
                             "If training is interrupted and resumed, if a pretrained model is defined "
                             "the pretrained model will be used instead of loading "
                             "the model that was being trained initially. Therefore, if you are resuming training "
                             "it is best to not set the pretrained_model_filename.",
                        required=False)
    # 添加一个命令行参数,指定训练日志文件的路径,参数名为--training_log_filename
    parser.add_argument("--training_log_filename",
                        help="CSV filename to save the to save the training and validation results for each epoch. "
                             "Default is '{output_dir}/{config_basename}/training_log.csv'",
                        required=False)
    # 添加一个命令行参数,覆盖配置文件中的批量大小,参数名为--batch_size
    parser.add_argument("--batch_size", help="Override the batch size from the config file.", type=int)
    # 添加一个命令行参数,指定是否开启调试模式的标志,参数名为--debug
    parser.add_argument("--debug", action="store_true", default=False,
                        help="Raises an error if a training file is not found. The default is to silently skip"
                             "any training files that cannot be found. Use this flag to debug the config for finding"
                             "the data.")
    # 调用一个自定义函数add_machine_config_to_parser,向命令行参数解析器添加机器配置相关的参数
    add_machine_config_to_parser(parser)
    # 添加一个命令行参数,指定用于调试目的的示例输入/输出对的数量,参数名为--n_examples
    parser.add_argument("--n_examples", type=int, default=1,
                        help="Number of example input/output pairs to write to file for debugging purposes. "
                             "(default = 1)")
    # 解析命令行参数,并将解析的结果存储在名为args的对象中
    args = parser.parse_args()
    # 返回解析的命令行参数对象
    return args

parse_args()参数分析:
config_filename指配置文件所在路径;
output_dir指输出所在路径,如未特殊指定默认与config_filename相同文件夹;
set_crossval_only为生成交叉验证配置文件的标志,默认是False,当为True时,只生成配置文件,不进行训练,此参数在后续run()中会说明;
pretrained_model_filename指预训练模型文件的路径,默认路径为‘{output_dir}/{config_basename}/model.pth’,可以没有,比如在直接下载下来的代码中就不包含预训练模型;
training_log_filename是训练日志文件的路径,默认为’{output_dir}/{config_basename}/training_log.csv’;
batch_size是训练batch大小,在config文件中已经规定,可以进行修改;
debug是调试模式,默认为False,True时进入调试模式;
n_examples是用于调试的参数,表示用于调试目的的示例输入/输出对的数量。

该段代码中–config_filename为必输参数,因此需要在运行程序时输入路径:

python /path/to/unet3d/scripts/train.py --config_filename brats2020_config.json

考虑到每次都需要输入路径较为麻烦,因此对其进行修改:

    parser.add_argument("--config_filename", required=False, default='C:\\Users\\15496\\Desktop\\3DUnetCNN-master\\examples\\brats2020\\brats2020_config.json',
                        help="JSON configuration file specifying the parameters for model training.")

其中将required修改为False,default后路径为brats2020_config.json所在绝对路径。

接下来说明一下brats2020_config.json文件:
config文件内包含了训练所需的所有参数:

“model”:
设定了输入输出、每一层的strides、filters以及卷积核的大小(333)、上采样核的大小(222)
“optimizer”:Adam
“loss”:DiceLoss
“cross_validation":交叉验证折数folds=5,随机数种子seed=25
“scheduler”:学习率设置
“dataset”::数据集设置
“training”:训练参数设置
以及训练集和测试集的名称

"scheduler": {
	#学习率调度器
    "name": "ReduceLROnPlateau",	
    #用于控制学习率降低的参数。它表示如果模型在验证集上的性能在连续10个训练周期(epoch)中都没有显著改善,那么学习率将被降低。这是为了应对训练过程中的性能停滞或收敛到局部最小值的情况。
    "patience": 10,		
    #学习率降低的因子。一旦达到了"patience"指定的停滞期,学习率将乘以这个因子,从而降低学习率的大小。在这里,学习率将减小到原来的一半。					
    "factor": 0.5,
    #学习率的最小值。学习率将不会降低到比这个值更小。这是为了确保学习率不会变得过小,防止训练过程过于缓慢。
    "min_lr": 1e-08
"dataset": {
    "name": "SegmentationDatasetPersistent",
    #调整数据集内图片大小
    "desired_shape": [
        128,
        128,
        128
    ],
    #指定数据集类别标签
    "labels": [
        2,
        1,
        4
    ],
    #(这个目前不太清楚)可能表示是否要设置类别标签的层次结构。如果设置为true,那么类别标签可能被组织成一个层次结构,而不仅仅是简单的类别标签。
    "setup_label_hierarchy": true,
    #图像归一化
    "normalization": "NormalizeIntensityD",
    "normalization_kwargs": {
        "channel_wise": true,
        "nonzero": false
    },
    #是否对数据进行重新采样。
    "resample": true,
    #是否在预处理中对图像进行前景裁剪。前景裁剪是一种常见的图像分割预处理技术,用于移除图像边缘周围的背景部分,以减小输入图像的大小,从而加速训练。
    "crop_foreground": true
"training": {
    "batch_size": 1,	# batch大小,表示逐个样本训练
    "validation_batch_size": 1,	#验证集batch大小
    "amp": false,			#是否启用混合精度训练,通过使用低位数的浮点数来加速模型训练。
    #表示早停止(early stopping)的耐心度,即在多少个训练周期内没有性能改善时停止训练。null表示没有设置早停止,训练将继续执行指定的训练周期数("n_epochs")。
    "early_stopping_patience": null,
    #这个参数指定了模型训练的总训练周期数。模型将在数据集上进行250个训练周期,然后停止(除非启用了早停止)。
    "n_epochs": 250,
    #这个参数可能表示模型保存的频率。如果设置为null,表示不会在特定的训练周期上保存模型。如果设置为一个正整数值,比如10,那么每10个训练周期将保存一次模型。
    "save_every_n_epochs": null,
    #这个参数可能表示要保存的最近模型的数量。如果设置为null,表示只保存最终模型。如果设置为一个正整数值,比如5,那么将保存最近的5个模型。
    "save_last_n_models": null,
    #这个参数表示是否要保存性能最佳的模型。如果设置为true,系统将在验证集上跟踪模型性能,并保存在验证集上性能最佳的模型。
    "save_best": true

最后是run()函数:

def run(config_filename, output_dir, namespace):
    # 打印配置文件的路径
    print("Config: ", config_filename)
    # 加载配置文件中的 JSON 数据,并将其存储在名为 config 的变量中
    config = load_json(config_filename)
    # 从配置文件中加载文件名
    load_filenames_from_config(config)
    # 根据配置文件的基本名称创建工作目录,并确保该目录存在
    work_dir = os.path.join(output_dir, os.path.basename(config_filename).split(".")[0])
    print("Work Dir:", work_dir)
    os.makedirs(work_dir, exist_ok=True)
    # 如果配置中包含交叉验证信息
    if "cross_validation" in config:
        # call parent function through each fold of the training set
        # 调用 setup_cross_validation 函数,为交叉验证设置配置
        cross_validation_config = config.pop("cross_validation")
        for _config, _config_filename in setup_cross_validation(config,
                                                                work_dir=work_dir,
                                                                n_folds=in_config("n_folds",
                                                                                  cross_validation_config,
                                                                                  5),
                                                                random_seed=in_config("random_seed",
                                                                                      cross_validation_config,
                                                                                      25)):
            # 如果不是仅设置交叉验证而不运行训练,则递归运行 run 函数
            if not namespace.setup_crossval_only:
                print("Running cross validation fold:", _config_filename)
                run(_config_filename, work_dir, namespace)
            else:
                # 否则,仅设置交叉验证
                print("Setup cross validation fold:", _config_filename)
    else:
        # run the training
        # 否则,执行模型训练
        # 获取系统配置信息
        system_config = get_machine_config(namespace)

        # set verbosity
        # 设置详细程度(verbosity),如果启用了调试模式
        if namespace.debug:
            if "dataset" not in config:
                config["dataset"] = dict()
            config["dataset"]["verbose"] = namespace.debug
            warnings.filterwarnings('error')

        # Override the batch size from the config file
        # 覆盖配置文件中的批量大小(batch size)
        if namespace.batch_size:
            warnings.warn(RuntimeWarning('Overwriting the batch size from the configuration file (batch_size={}) to '
                                         'batch_size={}'.format(config["training"]["batch_size"], namespace.batch_size)))
            config["training"]["batch_size"] = namespace.batch_size
        # 指定模型文件的路径
        model_filename = os.path.join(work_dir, "model.pth")
        print("Model: ", model_filename)
        # 指定训练日志文件的路径
        if namespace.training_log_filename:
            training_log_filename = namespace.training_log_filename
        else:
            training_log_filename = os.path.join(work_dir, "training_log.csv")
        print("Log: ", training_log_filename)
        # 检查标签层次结构
        label_hierarchy = check_hierarchy(config)
        # 加载数据集类
        dataset_class = load_dataset_class(config["dataset"], cache_dir=os.path.join(work_dir, "cache"))
        # 构建训练和验证数据加载器,以及需要监测的指标
        training_loader, validation_loader, metric_to_monitor = build_data_loaders_from_config(config,
                                                                                               system_config,
                                                                                               work_dir,
                                                                                               dataset_class)
        # 指定预训练模型文件的路径
        pretrained = namespace.pretrained_model_filename
        if pretrained:
            pretrained = os.path.abspath(pretrained)
        else:
            pretrained = model_filename
        # 构建或加载模型
        model = build_or_load_model_from_config(config,
                                                pretrained,
                                                system_config["n_gpus"])
        # 加载损失函数
        criterion = load_criterion_from_config(config, n_gpus=system_config["n_gpus"])
        # 构建优化器
        optimizer = build_optimizer(optimizer_name=config["optimizer"].pop("name"),
                                    model_parameters=model.parameters(),
                                    **config["optimizer"])
        # 构建学习率调度器
        scheduler = build_scheduler_from_config(config, optimizer)
        # 运行模型训练
        run_training(model=model.train(), optimizer=optimizer, criterion=criterion,
                     n_epochs=in_config("n_epochs", config["training"], 1000),
                     training_loader=training_loader, validation_loader=validation_loader,
                     model_filename=model_filename,
                     training_log_filename=training_log_filename,
                     metric_to_monitor=metric_to_monitor,
                     early_stopping_patience=in_config("early_stopping_patience", config["training"], None),
                     save_best=in_config("save_best", config["training"], True),
                     n_gpus=system_config["n_gpus"],
                     save_every_n_epochs=in_config("save_every_n_epochs", config["training"], None),
                     save_last_n_models=in_config("save_last_n_models", config["training"], None),
                     amp=in_config("amp", config["training"], None),
                     scheduler=scheduler,
                     samples_per_epoch=in_config("samples_per_epoch", config["training"], None))
        # 为推断过程构建数据加载器,并将结果保存到相应的目录中
        for _dataloader, _name in build_inference_loaders_from_config(config,
                                                                      dataset_class=dataset_class,
                                                                      system_config=system_config):
            prediction_dir = os.path.join(work_dir, _name)
            os.makedirs(prediction_dir, exist_ok=True)
            volumetric_predictions(model=model,
                                   dataloader=_dataloader,
                                   prediction_dir=prediction_dir,
                                   interpolation="trilinear",
                                   resample=in_config("resample", config["dataset"], False))

run()函数的注释如上所示,整体流程是这样的:
1.首先加载配置文件,并创建所需工作目录,如‘{output_dir}/{config_basename}’;

# 打印配置文件的路径
print("Config: ", config_filename)
# 加载配置文件中的 JSON 数据,并将其存储在名为 config 的变量中
config = load_json(config_filename)
# 从配置文件中加载文件名
load_filenames_from_config(config)
# 根据配置文件的基本名称创建工作目录,并确保该目录存在
work_dir = os.path.join(output_dir, os.path.basename(config_filename).split(".")[0])
print("Work Dir:", work_dir)
os.makedirs(work_dir, exist_ok=True)

2.然后如果配置文件中存在交叉验证信息,说明程序采用交叉验证进行训练(由于医学数据集较小,一般都采用交叉验证来进行训练),此时如果setup_crossval_only为False则会训练,如果为True则只设置交叉验证,不训练;

    # 如果配置中包含交叉验证信息
    if "cross_validation" in config:
        # call parent function through each fold of the training set
        # 调用 setup_cross_validation 函数,为交叉验证设置配置
        cross_validation_config = config.pop("cross_validation")
        for _config, _config_filename in setup_cross_validation(config,
                                                                work_dir=work_dir,
                                                                n_folds=in_config("n_folds",
                                                                                  cross_validation_config,
                                                                                  5),
                                                                random_seed=in_config("random_seed",
                                                                                      cross_validation_config,
                                                                                      25)):
            # 如果不是仅设置交叉验证而不运行训练,则递归运行 run 函数
            if not namespace.setup_crossval_only:
                print("Running cross validation fold:", _config_filename)
                run(_config_filename, work_dir, namespace)
            else:
                # 否则,仅设置交叉验证
                print("Setup cross validation fold:", _config_filename)

然后开始训练模型:
3.设置batch size

# Override the batch size from the config file
# 覆盖配置文件中的批量大小(batch size)
if namespace.batch_size:
warnings.warn(RuntimeWarning('Overwriting the batch size from the configuration file (batch_size={}) to '
                                         'batch_size={}'.format(config["training"]["batch_size"], namespace.batch_size)))
config["training"]["batch_size"] = namespace.batch_size

4.指定模型文件与训练日志文件的路径

# 指定模型文件的路径
model_filename = os.path.join(work_dir, "model.pth")
print("Model: ", model_filename)
# 指定训练日志文件的路径
if namespace.training_log_filename:
    training_log_filename = namespace.training_log_filename
else:
    training_log_filename = os.path.join(work_dir, "training_log.csv")
print("Log: ", training_log_filename)

5.加载数据集

# 检查标签层次结构
label_hierarchy = check_hierarchy(config)
# 加载数据集类
dataset_class = load_dataset_class(config["dataset"], cache_dir=os.path.join(work_dir, "cache"))
# 构建训练和验证数据加载器,以及需要监测的指标
training_loader, validation_loader, metric_to_monitor = build_data_loaders_from_config(config,
                                                                                       system_config,
                                                                                       work_dir,
                                                                                       dataset_class)

6.构建各个参数

# 构建或加载模型
model = build_or_load_model_from_config(config,
                                        pretrained,
                                        system_config["n_gpus"])
# 加载损失函数
criterion = load_criterion_from_config(config, n_gpus=system_config["n_gpus"])
# 构建优化器
optimizer = build_optimizer(optimizer_name=config["optimizer"].pop("name"),
                            model_parameters=model.parameters(),
                            **config["optimizer"])
# 构建学习率调度器
scheduler = build_scheduler_from_config(config, optimizer)

7.运行模型训练

# 运行模型训练
run_training(model=model.train(), optimizer=optimizer, criterion=criterion,
             n_epochs=in_config("n_epochs", config["training"], 1000),
             training_loader=training_loader, validation_loader=validation_loader,
             model_filename=model_filename,
             training_log_filename=training_log_filename,
             metric_to_monitor=metric_to_monitor,
             early_stopping_patience=in_config("early_stopping_patience", config["training"], None),
             save_best=in_config("save_best", config["training"], True),
             n_gpus=system_config["n_gpus"],
             save_every_n_epochs=in_config("save_every_n_epochs", config["training"], None),
             save_last_n_models=in_config("save_last_n_models", config["training"], None),
             amp=in_config("amp", config["training"], None),
             scheduler=scheduler,
             samples_per_epoch=in_config("samples_per_epoch", config["training"], None))

结合之前提到的brats2020_config.json讲解,接下来来看一看第6,7两点中的代码:
首先是6构建各个参数:该部分所引用的函数均在script_utils.py中:

构建或加载模型
model = build_or_load_model_from_config(config, pretrained,system_config[“n_gpus”])

def build_or_load_model_from_config(config, model_filename, n_gpus, strict=False):
    return build_or_load_model(config["model"].pop("name"), model_filename, n_gpus=n_gpus, **config["model"],
                               strict=strict)

加载模型加载的为brats2020_config.json中"model"部分参数,通过修改即可修改其模型。

加载损失函数
criterion = load_criterion_from_config(config, n_gpus=system_config[“n_gpus”])

def load_criterion_from_config(config, n_gpus):
    return load_criterion(config['loss'].pop("name"), n_gpus=n_gpus, loss_kwargs=config["loss"])

加载损失函数加载的为brats2020_config.json中"loss"部分参数,通过修改即可修改其损失函数。

构建优化器
optimizer = build_optimizer(optimizer_name=config[“optimizer”].pop(“name”), model_parameters=model.parameters(), **config[“optimizer”])

def build_optimizer(optimizer_name, model_parameters, **kwargs):
    return getattr(torch.optim, optimizer_name)(params=model_parameters, **kwargs)

加载优化器加载的为brats2020_config.json中"optimizer"部分参数,通过修改即可修改其优化器。

构建学习率调度器
scheduler = build_scheduler_from_config(config, optimizer)

def build_scheduler_from_config(config, optimizer):
    if "scheduler" not in config:
        scheduler = None
    else:
        scheduler_class = getattr(torch.optim.lr_scheduler, config["scheduler"].pop("name"))
        scheduler = scheduler_class(optimizer, **config["scheduler"])
    return scheduler

加载学习率调度器加载的为brats2020_config.json中"scheduler"部分参数,通过修改即可修改其学习率调度器。

接下来看一看7运行模型训练
其中run_training()函数在unet3d/train/train.py中,其注释如下:

def run_training(model, optimizer, criterion, n_epochs, training_loader, validation_loader, training_log_filename,
                 model_filename, metric_to_monitor="val_loss", early_stopping_patience=None,
                 save_best=False, n_gpus=1, save_every_n_epochs=None,
                 save_last_n_models=None, amp=False, scheduler=None,
                 samples_per_epoch=None):
    # 创建一个空列表,用于存储训练期间的日志信息
    training_log = list()
    # 检查指定的训练日志文件是否存在
    if os.path.exists(training_log_filename):
        # 如果日志文件存在,从文件中读取日志数据并将其添加到训练日志列表
        training_log.extend(pd.read_csv(training_log_filename).values)
        # 计算起始的训练轮次
        start_epoch = int(training_log[-1][0]) + 1
    else:
        # 如果日志文件不存在,将起始训练轮次设置为1
        start_epoch = 1
    # 定义训练日志的列标题
    training_log_header = ["epoch", "loss", "lr", "val_loss"]
    # 如果存在调度器并且起始训练轮次大于1,则需要为之前的轮次执行调度器和优化器的步进操作
    if scheduler is not None and start_epoch > 1:
        # step the scheduler and optimizer to account for previous epochs
        # 逐轮次执行调度器和优化器的步进操作,以考虑之前的轮次
        for i in range(1, start_epoch):
            optimizer.step()
            # 如果调度器是 ReduceLROnPlateau 类型,则考虑之前轮次的指标来更新学习率
            if scheduler.__class__ == torch.optim.lr_scheduler.ReduceLROnPlateau:
                metric = np.asarray(training_log)[i - 1, training_log_header.index(metric_to_monitor)]
                scheduler.step(metric)
            else:
                scheduler.step()
    # 如果启用混合精度训练,则创建一个梯度缩放器,否则设置为None
    if amp:
        from torch.cuda.amp import GradScaler
        scaler = GradScaler()
    else:
        scaler = None
    # 开始训练循环,逐轮次执行训练和验证
    for epoch in range(start_epoch, n_epochs+1):
        # early stopping
        # 执行早停操作,如果满足早停条件,则提前结束训练循环
        if training_log:
            metric = np.asarray(training_log)[:, training_log_header.index(metric_to_monitor)]
        if (training_log and early_stopping_patience
                and metric.argmin() <= len(training_log) - early_stopping_patience):
            print("Early stopping patience {} has been reached.".format(early_stopping_patience))
            break
        # 检查是否存在无效的结果,如果存在,则提前结束训练循环
        if training_log and np.isnan(metric[-1]):
            print("Stopping as invalid results were returned.")
            break

        # train the model
        # 执行模型训练
        loss = epoch_training(training_loader, model, criterion, optimizer=optimizer, epoch=epoch, n_gpus=n_gpus,
                              scaler=scaler, samples_per_epoch=samples_per_epoch)

        # Clear the cache from the GPUs
        # 清空GPU的缓存,以释放内存
        if n_gpus:
            torch.cuda.empty_cache()

        # predict validation data
        # 预测验证数据集
        if validation_loader:
            val_loss = epoch_validatation(validation_loader, model, criterion, n_gpus=n_gpus,
                                          use_amp=scaler is not None)
        else:
            val_loss = None

        # update the training log
        # 更新训练日志
        training_log.append([epoch, loss, get_lr(optimizer), val_loss])
        # 将训练日志保存到CSV文件中
        pd.DataFrame(training_log, columns=training_log_header).set_index("epoch").to_csv(training_log_filename)
        # 找到具有最小验证损失的轮次
        min_epoch = np.asarray(training_log)[:, training_log_header.index(metric_to_monitor)].argmin()

        # check loss and decay
        # 检查损失和学习率调度器
        if scheduler:
            if validation_loader and scheduler.__class__ == torch.optim.lr_scheduler.ReduceLROnPlateau:
                scheduler.step(val_loss)
            elif scheduler.__class__ == torch.optim.lr_scheduler.ReduceLROnPlateau:
                scheduler.step(loss)
            else:
                scheduler.step()

        # save model
        # 保存模型
        if n_gpus > 1:
            # 如果有多个GPU,则保存模型的状态字典
            torch.save(model.module.state_dict(), model_filename)
        else:
            # 如果只有一个GPU,则保存模型的状态字典
            torch.save(model.state_dict(), model_filename)
        # 如果设置了保存最佳模型,且当前轮次具有最小验证损失,则将模型复制为"best"版本
        if save_best and min_epoch == len(training_log) - 1:
            best_filename = append_to_filename(model_filename, "best")
            forced_copy(model_filename, best_filename)
        # 如果设置了每N轮保存模型,且当前轮次是指定的轮次,则将模型保存为单独的版本
        if save_every_n_epochs and (epoch % save_every_n_epochs) == 0:
            epoch_filename = append_to_filename(model_filename, epoch)
            forced_copy(model_filename, epoch_filename)
        # 如果设置了保存最后N个模型,且当前轮次不是保存轮次,则删除旧版本模型文件
        if save_last_n_models is not None and save_last_n_models > 1:
            if not save_every_n_epochs or not ((epoch - save_last_n_models) % save_every_n_epochs) == 0:
                to_delete = append_to_filename(model_filename, epoch - save_last_n_models)
                remove_file(to_delete)
            epoch_filename = append_to_filename(model_filename, epoch)
            forced_copy(model_filename, epoch_filename)

其中每个epoch_trainging()在unet3d/train/training_utils.py中,注释如下:

def epoch_training(train_loader, model, criterion, optimizer, epoch, n_gpus=None, print_frequency=1,
                   print_gpu_memory=False, scaler=None, samples_per_epoch=None):
    # 创建用于度量批处理时间、数据加载时间和损失的平均值的对象
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    # 创建用于显示训练进度的对象
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses],
        prefix="Epoch: [{}]".format(epoch))
    # 检查是否启用了混合精度训练(Automatic Mixed Precision, AMP)
    use_amp = scaler is not None

    # switch to train mode
    # 将模型切换到训练模式
    model.train()

    end = time.time()
    for i, item in enumerate(train_loader):
        images = item["image"]
        target = item["label"]
        # measure data loading time
        # 记录数据加载时间
        data_time.update(time.time() - end)
        # 如果有多个 GPU,则清空 GPU 缓存
        if n_gpus:
            torch.cuda.empty_cache()
            if print_gpu_memory:
                for i_gpu in range(n_gpus):
                    print("Memory allocated (device {}):".format(i_gpu),
                          human_readable_size(torch.cuda.memory_allocated(i_gpu)))
                    print("Max memory allocated (device {}):".format(i_gpu),
                          human_readable_size(torch.cuda.max_memory_allocated(i_gpu)))
                    print("Memory cached (device {}):".format(i_gpu),
                          human_readable_size(torch.cuda.memory_cached(i_gpu)))
                    print("Max memory cached (device {}):".format(i_gpu),
                          human_readable_size(torch.cuda.max_memory_cached(i_gpu)))
        # 优化器梯度清零
        optimizer.zero_grad()
        # 计算批次的损失值和批次大小
        loss, batch_size = batch_loss(model, images, target, criterion, n_gpus=n_gpus, use_amp=use_amp)

        # measure accuracy and record loss
        # 记录损失值
        losses.update(loss.item(), batch_size)
        # 如果启用混合精度训练,则使用 scaler 进行梯度缩放和反向传播
        if scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # compute gradient and do step
            # 否则,直接进行反向传播和优化器步骤
            loss.backward()
            optimizer.step()
        # 释放 loss 变量的内存
        del loss

        # measure elapsed time
        # 记录批次时间
        batch_time.update(time.time() - end)
        end = time.time()
        # 每达到一定的打印频率,显示进度信息
        if i % print_frequency == 0:
            progress.display(i+1)
        # 如果达到指定的 samples_per_epoch 数量,则提前结束 epoch
        if samples_per_epoch and (i + 1) * batch_size >= samples_per_epoch:
            break
    # 返回平均损失
    return losses.avg

源码较为复杂,目前仅看懂了一部分,随时补充,欢迎大家一起交流,因为本人为初学者,所以有写的不准确或者不对的地方欢迎指正,我会在第一时间进行修改。

后续将在此基础上自己制作数据集并尝试修改代码训练。

第一篇bolg,希望大家觉得不错可以点赞支持谢谢!

评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值