医学图像处理算法学习——3DUnetCNN(1)
第一章 3DUnetCNN(1)——原demo部署运行于源码分析
前言
作为初入医学图像处理领域的新手,本博客旨在记录自己对于3DUnetCNN学习过程中遇到的问题,以及自己对代码的理解与思考,包括后续如何将自己的数据集转换为该程序可用的数据集,供自己后续回看,也供大家一起学习参考。
一、源码与数据集
代码源码下载自GitHub:link
本代码数据集采用的是BraTS2020图像库,下载地址:link
需分别下载训练集与验证集,下载后将其按下述格式放置于examples文件夹中

二、运行训练demo
1.搭建环境
首先创建其自己的虚拟环境
安装安装所需库
代码如下(示例):
pip install -r 3DUnetCNN/requirements.txt
==注意:==如未安装torch,需先安装torch(GPU版本)
安装好后在cmd运行下述代码确认是否正确安装:
python
import torch
print(torch.cuda.is_available())
#返回True说明安装成功
2.运行代码
代码如下(示例):
cd examples/brats2020 #打开至数据集路径
#注意源代码一定要在该路径下运行,否则会找不到数据集报错
python /path/to/unet3d/scripts/train.py --config_filename brats2020_config.json
#/path/to/unet3d/scripts/train.py替换为scripts/train.py所在的绝对路径
注意:
1.该demo运行时遇到一些问题,因此对其中部分文件进行了微调,在后续会指出。
2.注意源代码一定要在examples/brats2020路径下运行,否则会找不到数据集报错。
3.BraTS2020数据集较大,若不修改直接运行demo,需要较大的GPU内存,否则会报错torch.cuda.OutOfMemoryError: CUDA out of memory.
三、代码分析
在运行成功后,本部分旨在对3DUnetCNN源码进行分析,了解其结构与每部分代码功能,方便后续修改与确定自己的数据集,在本段中为主要代码添加了注释。
1.代码结构

整体代码结构如上图所示,目前进了解了部分代码,其中examples/brats2020中存放的是数据集,brats2020_config.json为训练数据集的信息。

unet3d中为整个算法的主体部分,其中scripts/train.py与scripts/perdict.py为训练和预测的主函数,其他函数为其中各个单元。
2.train.py
注意: 由于train.py中对其他各个文件进行了引用与跳转,因此在下面讲解中将按照逻辑跳转不同文件进行分析,同时指出在运行demo时进行修改的部分。
首先是引用库,由于虽然unet3d文件夹中包含了__init__.py文件,但是仍报错找不到unet3d库,因此在程序中手动添加路径代码
sys.path.append('C:\\Users\\15496\\Desktop\\3DUnetCNN-master')
#路径为算法所在绝对路径

接下来是最下端的主函数部分,其重点内容为parse_args()与run(),parse_args()是参数定义函数。run()是训练算法,在main()中,首先定义参数,然后执行运行函数。后续将对这两个函数进行详细讲解。
首先是parse_args()函数:
def parse_args():
# 创建一个命令行参数解析器对象
parser = argparse.ArgumentParser()
# 添加一个命令行参数,指定配置文件的路径,参数名为--config_filename,必须提供该参数
parser.add_argument("--config_filename", required=True,
help="JSON configuration file specifying the parameters for model training.")
# 添加一个命令行参数,指定输出目录的路径,参数名为--output_dir,是可选参数
parser.add_argument("--output_dir", required=False,
help="Output directory where all the outputs will be saved. "
"Defaults to the directory of the configuration file.")
# 添加一个命令行参数,指定是否只生成交叉验证配置文件的标志,参数名为--setup_crossval_only
parser.add_argument("--setup_crossval_only", action="store_true", default=False,
help="Only write the cross-validation configuration files. "
"If selected, training will not be run. Instead the filenames will be split into "
"folds and modified configuration files will be written to the working directory. "
"This is useful if you want to submit training folds to an HPC scheduler system.")
# 添加一个命令行参数,指定预训练模型文件的路径,参数名为--pretrained_model_filename
parser.add_argument("--pretrained_model_filename",
help="If this filename exists prior to training, the model will be loaded from the filename. "
"Default is '{output_dir}/{config_basename}/model.pth'. "
"The default behavior is to use flexible loading of the model where not all the "
"model layers/weights have to match. "
"If training is interrupted and resumed, if a pretrained model is defined "
"the pretrained model will be used instead of loading "
"the model that was being trained initially. Therefore, if you are resuming training "
"it is best to not set the pretrained_model_filename.",
required=False)
# 添加一个命令行参数,指定训练日志文件的路径,参数名为--training_log_filename
parser.add_argument("--training_log_filename",
help="CSV filename to save the to save the training and validation results for each epoch. "
"Default is '{output_dir}/{config_basename}/training_log.csv'",
required=False)
# 添加一个命令行参数,覆盖配置文件中的批量大小,参数名为--batch_size
parser.add_argument("--batch_size", help="Override the batch size from the config file.", type=int)
# 添加一个命令行参数,指定是否开启调试模式的标志,参数名为--debug
parser.add_argument("--debug", action="store_true", default=False,
help="Raises an error if a training file is not found. The default is to silently skip"
"any training files that cannot be found. Use this flag to debug the config for finding"
"the data.")
# 调用一个自定义函数add_machine_config_to_parser,向命令行参数解析器添加机器配置相关的参数
add_machine_config_to_parser(parser)
# 添加一个命令行参数,指定用于调试目的的示例输入/输出对的数量,参数名为--n_examples
parser.add_argument("--n_examples", type=int, default=1,
help="Number of example input/output pairs to write to file for debugging purposes. "
"(default = 1)")
# 解析命令行参数,并将解析的结果存储在名为args的对象中
args = parser.parse_args()
# 返回解析的命令行参数对象
return args
parse_args()参数分析:
config_filename指配置文件所在路径;
output_dir指输出所在路径,如未特殊指定默认与config_filename相同文件夹;
set_crossval_only为生成交叉验证配置文件的标志,默认是False,当为True时,只生成配置文件,不进行训练,此参数在后续run()中会说明;
pretrained_model_filename指预训练模型文件的路径,默认路径为‘{output_dir}/{config_basename}/model.pth’,可以没有,比如在直接下载下来的代码中就不包含预训练模型;
training_log_filename是训练日志文件的路径,默认为’{output_dir}/{config_basename}/training_log.csv’;
batch_size是训练batch大小,在config文件中已经规定,可以进行修改;
debug是调试模式,默认为False,True时进入调试模式;
n_examples是用于调试的参数,表示用于调试目的的示例输入/输出对的数量。
该段代码中–config_filename为必输参数,因此需要在运行程序时输入路径:
python /path/to/unet3d/scripts/train.py --config_filename brats2020_config.json
考虑到每次都需要输入路径较为麻烦,因此对其进行修改:
parser.add_argument("--config_filename", required=False, default='C:\\Users\\15496\\Desktop\\3DUnetCNN-master\\examples\\brats2020\\brats2020_config.json',
help="JSON configuration file specifying the parameters for model training.")
其中将required修改为False,default后路径为brats2020_config.json所在绝对路径。
接下来说明一下brats2020_config.json文件:
config文件内包含了训练所需的所有参数:
“model”:
设定了输入输出、每一层的strides、filters以及卷积核的大小(333)、上采样核的大小(222)
“optimizer”:Adam
“loss”:DiceLoss
“cross_validation":交叉验证折数folds=5,随机数种子seed=25
“scheduler”:学习率设置
“dataset”::数据集设置
“training”:训练参数设置
以及训练集和测试集的名称
"scheduler": {
#学习率调度器
"name": "ReduceLROnPlateau",
#用于控制学习率降低的参数。它表示如果模型在验证集上的性能在连续10个训练周期(epoch)中都没有显著改善,那么学习率将被降低。这是为了应对训练过程中的性能停滞或收敛到局部最小值的情况。
"patience": 10,
#学习率降低的因子。一旦达到了"patience"指定的停滞期,学习率将乘以这个因子,从而降低学习率的大小。在这里,学习率将减小到原来的一半。
"factor": 0.5,
#学习率的最小值。学习率将不会降低到比这个值更小。这是为了确保学习率不会变得过小,防止训练过程过于缓慢。
"min_lr": 1e-08
"dataset": {
"name": "SegmentationDatasetPersistent",
#调整数据集内图片大小
"desired_shape": [
128,
128,
128
],
#指定数据集类别标签
"labels": [
2,
1,
4
],
#(这个目前不太清楚)可能表示是否要设置类别标签的层次结构。如果设置为true,那么类别标签可能被组织成一个层次结构,而不仅仅是简单的类别标签。
"setup_label_hierarchy": true,
#图像归一化
"normalization": "NormalizeIntensityD",
"normalization_kwargs": {
"channel_wise": true,
"nonzero": false
},
#是否对数据进行重新采样。
"resample": true,
#是否在预处理中对图像进行前景裁剪。前景裁剪是一种常见的图像分割预处理技术,用于移除图像边缘周围的背景部分,以减小输入图像的大小,从而加速训练。
"crop_foreground": true
"training": {
"batch_size": 1, # batch大小,表示逐个样本训练
"validation_batch_size": 1, #验证集batch大小
"amp": false, #是否启用混合精度训练,通过使用低位数的浮点数来加速模型训练。
#表示早停止(early stopping)的耐心度,即在多少个训练周期内没有性能改善时停止训练。null表示没有设置早停止,训练将继续执行指定的训练周期数("n_epochs")。
"early_stopping_patience": null,
#这个参数指定了模型训练的总训练周期数。模型将在数据集上进行250个训练周期,然后停止(除非启用了早停止)。
"n_epochs": 250,
#这个参数可能表示模型保存的频率。如果设置为null,表示不会在特定的训练周期上保存模型。如果设置为一个正整数值,比如10,那么每10个训练周期将保存一次模型。
"save_every_n_epochs": null,
#这个参数可能表示要保存的最近模型的数量。如果设置为null,表示只保存最终模型。如果设置为一个正整数值,比如5,那么将保存最近的5个模型。
"save_last_n_models": null,
#这个参数表示是否要保存性能最佳的模型。如果设置为true,系统将在验证集上跟踪模型性能,并保存在验证集上性能最佳的模型。
"save_best": true
最后是run()函数:
def run(config_filename, output_dir, namespace):
# 打印配置文件的路径
print("Config: ", config_filename)
# 加载配置文件中的 JSON 数据,并将其存储在名为 config 的变量中
config = load_json(config_filename)
# 从配置文件中加载文件名
load_filenames_from_config(config)
# 根据配置文件的基本名称创建工作目录,并确保该目录存在
work_dir = os.path.join(output_dir, os.path.basename(config_filename).split(".")[0])
print("Work Dir:", work_dir)
os.makedirs(work_dir, exist_ok=True)
# 如果配置中包含交叉验证信息
if "cross_validation" in config:
# call parent function through each fold of the training set
# 调用 setup_cross_validation 函数,为交叉验证设置配置
cross_validation_config = config.pop("cross_validation")
for _config, _config_filename in setup_cross_validation(config,
work_dir=work_dir,
n_folds=in_config("n_folds",
cross_validation_config,
5),
random_seed=in_config("random_seed",
cross_validation_config,
25)):
# 如果不是仅设置交叉验证而不运行训练,则递归运行 run 函数
if not namespace.setup_crossval_only:
print("Running cross validation fold:", _config_filename)
run(_config_filename, work_dir, namespace)
else:
# 否则,仅设置交叉验证
print("Setup cross validation fold:", _config_filename)
else:
# run the training
# 否则,执行模型训练
# 获取系统配置信息
system_config = get_machine_config(namespace)
# set verbosity
# 设置详细程度(verbosity),如果启用了调试模式
if namespace.debug:
if "dataset" not in config:
config["dataset"] = dict()
config["dataset"]["verbose"] = namespace.debug
warnings.filterwarnings('error')
# Override the batch size from the config file
# 覆盖配置文件中的批量大小(batch size)
if namespace.batch_size:
warnings.warn(RuntimeWarning('Overwriting the batch size from the configuration file (batch_size={}) to '
'batch_size={}'.format(config["training"]["batch_size"], namespace.batch_size)))
config["training"]["batch_size"] = namespace.batch_size
# 指定模型文件的路径
model_filename = os.path.join(work_dir, "model.pth")
print("Model: ", model_filename)
# 指定训练日志文件的路径
if namespace.training_log_filename:
training_log_filename = namespace.training_log_filename
else:
training_log_filename = os.path.join(work_dir, "training_log.csv")
print("Log: ", training_log_filename)
# 检查标签层次结构
label_hierarchy = check_hierarchy(config)
# 加载数据集类
dataset_class = load_dataset_class(config["dataset"], cache_dir=os.path.join(work_dir, "cache"))
# 构建训练和验证数据加载器,以及需要监测的指标
training_loader, validation_loader, metric_to_monitor = build_data_loaders_from_config(config,
system_config,
work_dir,
dataset_class)
# 指定预训练模型文件的路径
pretrained = namespace.pretrained_model_filename
if pretrained:
pretrained = os.path.abspath(pretrained)
else:
pretrained = model_filename
# 构建或加载模型
model = build_or_load_model_from_config(config,
pretrained,
system_config["n_gpus"])
# 加载损失函数
criterion = load_criterion_from_config(config, n_gpus=system_config["n_gpus"])
# 构建优化器
optimizer = build_optimizer(optimizer_name=config["optimizer"].pop("name"),
model_parameters=model.parameters(),
**config["optimizer"])
# 构建学习率调度器
scheduler = build_scheduler_from_config(config, optimizer)
# 运行模型训练
run_training(model=model.train(), optimizer=optimizer, criterion=criterion,
n_epochs=in_config("n_epochs", config["training"], 1000),
training_loader=training_loader, validation_loader=validation_loader,
model_filename=model_filename,
training_log_filename=training_log_filename,
metric_to_monitor=metric_to_monitor,
early_stopping_patience=in_config("early_stopping_patience", config["training"], None),
save_best=in_config("save_best", config["training"], True),
n_gpus=system_config["n_gpus"],
save_every_n_epochs=in_config("save_every_n_epochs", config["training"], None),
save_last_n_models=in_config("save_last_n_models", config["training"], None),
amp=in_config("amp", config["training"], None),
scheduler=scheduler,
samples_per_epoch=in_config("samples_per_epoch", config["training"], None))
# 为推断过程构建数据加载器,并将结果保存到相应的目录中
for _dataloader, _name in build_inference_loaders_from_config(config,
dataset_class=dataset_class,
system_config=system_config):
prediction_dir = os.path.join(work_dir, _name)
os.makedirs(prediction_dir, exist_ok=True)
volumetric_predictions(model=model,
dataloader=_dataloader,
prediction_dir=prediction_dir,
interpolation="trilinear",
resample=in_config("resample", config["dataset"], False))
run()函数的注释如上所示,整体流程是这样的:
1.首先加载配置文件,并创建所需工作目录,如‘{output_dir}/{config_basename}’;
# 打印配置文件的路径
print("Config: ", config_filename)
# 加载配置文件中的 JSON 数据,并将其存储在名为 config 的变量中
config = load_json(config_filename)
# 从配置文件中加载文件名
load_filenames_from_config(config)
# 根据配置文件的基本名称创建工作目录,并确保该目录存在
work_dir = os.path.join(output_dir, os.path.basename(config_filename).split(".")[0])
print("Work Dir:", work_dir)
os.makedirs(work_dir, exist_ok=True)
2.然后如果配置文件中存在交叉验证信息,说明程序采用交叉验证进行训练(由于医学数据集较小,一般都采用交叉验证来进行训练),此时如果setup_crossval_only为False则会训练,如果为True则只设置交叉验证,不训练;
# 如果配置中包含交叉验证信息
if "cross_validation" in config:
# call parent function through each fold of the training set
# 调用 setup_cross_validation 函数,为交叉验证设置配置
cross_validation_config = config.pop("cross_validation")
for _config, _config_filename in setup_cross_validation(config,
work_dir=work_dir,
n_folds=in_config("n_folds",
cross_validation_config,
5),
random_seed=in_config("random_seed",
cross_validation_config,
25)):
# 如果不是仅设置交叉验证而不运行训练,则递归运行 run 函数
if not namespace.setup_crossval_only:
print("Running cross validation fold:", _config_filename)
run(_config_filename, work_dir, namespace)
else:
# 否则,仅设置交叉验证
print("Setup cross validation fold:", _config_filename)
然后开始训练模型:
3.设置batch size
# Override the batch size from the config file
# 覆盖配置文件中的批量大小(batch size)
if namespace.batch_size:
warnings.warn(RuntimeWarning('Overwriting the batch size from the configuration file (batch_size={}) to '
'batch_size={}'.format(config["training"]["batch_size"], namespace.batch_size)))
config["training"]["batch_size"] = namespace.batch_size
4.指定模型文件与训练日志文件的路径
# 指定模型文件的路径
model_filename = os.path.join(work_dir, "model.pth")
print("Model: ", model_filename)
# 指定训练日志文件的路径
if namespace.training_log_filename:
training_log_filename = namespace.training_log_filename
else:
training_log_filename = os.path.join(work_dir, "training_log.csv")
print("Log: ", training_log_filename)
5.加载数据集
# 检查标签层次结构
label_hierarchy = check_hierarchy(config)
# 加载数据集类
dataset_class = load_dataset_class(config["dataset"], cache_dir=os.path.join(work_dir, "cache"))
# 构建训练和验证数据加载器,以及需要监测的指标
training_loader, validation_loader, metric_to_monitor = build_data_loaders_from_config(config,
system_config,
work_dir,
dataset_class)
6.构建各个参数
# 构建或加载模型
model = build_or_load_model_from_config(config,
pretrained,
system_config["n_gpus"])
# 加载损失函数
criterion = load_criterion_from_config(config, n_gpus=system_config["n_gpus"])
# 构建优化器
optimizer = build_optimizer(optimizer_name=config["optimizer"].pop("name"),
model_parameters=model.parameters(),
**config["optimizer"])
# 构建学习率调度器
scheduler = build_scheduler_from_config(config, optimizer)
7.运行模型训练
# 运行模型训练
run_training(model=model.train(), optimizer=optimizer, criterion=criterion,
n_epochs=in_config("n_epochs", config["training"], 1000),
training_loader=training_loader, validation_loader=validation_loader,
model_filename=model_filename,
training_log_filename=training_log_filename,
metric_to_monitor=metric_to_monitor,
early_stopping_patience=in_config("early_stopping_patience", config["training"], None),
save_best=in_config("save_best", config["training"], True),
n_gpus=system_config["n_gpus"],
save_every_n_epochs=in_config("save_every_n_epochs", config["training"], None),
save_last_n_models=in_config("save_last_n_models", config["training"], None),
amp=in_config("amp", config["training"], None),
scheduler=scheduler,
samples_per_epoch=in_config("samples_per_epoch", config["training"], None))
结合之前提到的brats2020_config.json讲解,接下来来看一看第6,7两点中的代码:
首先是6构建各个参数:该部分所引用的函数均在script_utils.py中:
构建或加载模型
model = build_or_load_model_from_config(config, pretrained,system_config[“n_gpus”])
def build_or_load_model_from_config(config, model_filename, n_gpus, strict=False):
return build_or_load_model(config["model"].pop("name"), model_filename, n_gpus=n_gpus, **config["model"],
strict=strict)
加载模型加载的为brats2020_config.json中"model"部分参数,通过修改即可修改其模型。
加载损失函数
criterion = load_criterion_from_config(config, n_gpus=system_config[“n_gpus”])
def load_criterion_from_config(config, n_gpus):
return load_criterion(config['loss'].pop("name"), n_gpus=n_gpus, loss_kwargs=config["loss"])
加载损失函数加载的为brats2020_config.json中"loss"部分参数,通过修改即可修改其损失函数。
构建优化器
optimizer = build_optimizer(optimizer_name=config[“optimizer”].pop(“name”), model_parameters=model.parameters(), **config[“optimizer”])
def build_optimizer(optimizer_name, model_parameters, **kwargs):
return getattr(torch.optim, optimizer_name)(params=model_parameters, **kwargs)
加载优化器加载的为brats2020_config.json中"optimizer"部分参数,通过修改即可修改其优化器。
构建学习率调度器
scheduler = build_scheduler_from_config(config, optimizer)
def build_scheduler_from_config(config, optimizer):
if "scheduler" not in config:
scheduler = None
else:
scheduler_class = getattr(torch.optim.lr_scheduler, config["scheduler"].pop("name"))
scheduler = scheduler_class(optimizer, **config["scheduler"])
return scheduler
加载学习率调度器加载的为brats2020_config.json中"scheduler"部分参数,通过修改即可修改其学习率调度器。
接下来看一看7运行模型训练
其中run_training()函数在unet3d/train/train.py中,其注释如下:
def run_training(model, optimizer, criterion, n_epochs, training_loader, validation_loader, training_log_filename,
model_filename, metric_to_monitor="val_loss", early_stopping_patience=None,
save_best=False, n_gpus=1, save_every_n_epochs=None,
save_last_n_models=None, amp=False, scheduler=None,
samples_per_epoch=None):
# 创建一个空列表,用于存储训练期间的日志信息
training_log = list()
# 检查指定的训练日志文件是否存在
if os.path.exists(training_log_filename):
# 如果日志文件存在,从文件中读取日志数据并将其添加到训练日志列表
training_log.extend(pd.read_csv(training_log_filename).values)
# 计算起始的训练轮次
start_epoch = int(training_log[-1][0]) + 1
else:
# 如果日志文件不存在,将起始训练轮次设置为1
start_epoch = 1
# 定义训练日志的列标题
training_log_header = ["epoch", "loss", "lr", "val_loss"]
# 如果存在调度器并且起始训练轮次大于1,则需要为之前的轮次执行调度器和优化器的步进操作
if scheduler is not None and start_epoch > 1:
# step the scheduler and optimizer to account for previous epochs
# 逐轮次执行调度器和优化器的步进操作,以考虑之前的轮次
for i in range(1, start_epoch):
optimizer.step()
# 如果调度器是 ReduceLROnPlateau 类型,则考虑之前轮次的指标来更新学习率
if scheduler.__class__ == torch.optim.lr_scheduler.ReduceLROnPlateau:
metric = np.asarray(training_log)[i - 1, training_log_header.index(metric_to_monitor)]
scheduler.step(metric)
else:
scheduler.step()
# 如果启用混合精度训练,则创建一个梯度缩放器,否则设置为None
if amp:
from torch.cuda.amp import GradScaler
scaler = GradScaler()
else:
scaler = None
# 开始训练循环,逐轮次执行训练和验证
for epoch in range(start_epoch, n_epochs+1):
# early stopping
# 执行早停操作,如果满足早停条件,则提前结束训练循环
if training_log:
metric = np.asarray(training_log)[:, training_log_header.index(metric_to_monitor)]
if (training_log and early_stopping_patience
and metric.argmin() <= len(training_log) - early_stopping_patience):
print("Early stopping patience {} has been reached.".format(early_stopping_patience))
break
# 检查是否存在无效的结果,如果存在,则提前结束训练循环
if training_log and np.isnan(metric[-1]):
print("Stopping as invalid results were returned.")
break
# train the model
# 执行模型训练
loss = epoch_training(training_loader, model, criterion, optimizer=optimizer, epoch=epoch, n_gpus=n_gpus,
scaler=scaler, samples_per_epoch=samples_per_epoch)
# Clear the cache from the GPUs
# 清空GPU的缓存,以释放内存
if n_gpus:
torch.cuda.empty_cache()
# predict validation data
# 预测验证数据集
if validation_loader:
val_loss = epoch_validatation(validation_loader, model, criterion, n_gpus=n_gpus,
use_amp=scaler is not None)
else:
val_loss = None
# update the training log
# 更新训练日志
training_log.append([epoch, loss, get_lr(optimizer), val_loss])
# 将训练日志保存到CSV文件中
pd.DataFrame(training_log, columns=training_log_header).set_index("epoch").to_csv(training_log_filename)
# 找到具有最小验证损失的轮次
min_epoch = np.asarray(training_log)[:, training_log_header.index(metric_to_monitor)].argmin()
# check loss and decay
# 检查损失和学习率调度器
if scheduler:
if validation_loader and scheduler.__class__ == torch.optim.lr_scheduler.ReduceLROnPlateau:
scheduler.step(val_loss)
elif scheduler.__class__ == torch.optim.lr_scheduler.ReduceLROnPlateau:
scheduler.step(loss)
else:
scheduler.step()
# save model
# 保存模型
if n_gpus > 1:
# 如果有多个GPU,则保存模型的状态字典
torch.save(model.module.state_dict(), model_filename)
else:
# 如果只有一个GPU,则保存模型的状态字典
torch.save(model.state_dict(), model_filename)
# 如果设置了保存最佳模型,且当前轮次具有最小验证损失,则将模型复制为"best"版本
if save_best and min_epoch == len(training_log) - 1:
best_filename = append_to_filename(model_filename, "best")
forced_copy(model_filename, best_filename)
# 如果设置了每N轮保存模型,且当前轮次是指定的轮次,则将模型保存为单独的版本
if save_every_n_epochs and (epoch % save_every_n_epochs) == 0:
epoch_filename = append_to_filename(model_filename, epoch)
forced_copy(model_filename, epoch_filename)
# 如果设置了保存最后N个模型,且当前轮次不是保存轮次,则删除旧版本模型文件
if save_last_n_models is not None and save_last_n_models > 1:
if not save_every_n_epochs or not ((epoch - save_last_n_models) % save_every_n_epochs) == 0:
to_delete = append_to_filename(model_filename, epoch - save_last_n_models)
remove_file(to_delete)
epoch_filename = append_to_filename(model_filename, epoch)
forced_copy(model_filename, epoch_filename)
其中每个epoch_trainging()在unet3d/train/training_utils.py中,注释如下:
def epoch_training(train_loader, model, criterion, optimizer, epoch, n_gpus=None, print_frequency=1,
print_gpu_memory=False, scaler=None, samples_per_epoch=None):
# 创建用于度量批处理时间、数据加载时间和损失的平均值的对象
batch_time = AverageMeter('Time', ':6.3f')
data_time = AverageMeter('Data', ':6.3f')
losses = AverageMeter('Loss', ':.4e')
# 创建用于显示训练进度的对象
progress = ProgressMeter(
len(train_loader),
[batch_time, data_time, losses],
prefix="Epoch: [{}]".format(epoch))
# 检查是否启用了混合精度训练(Automatic Mixed Precision, AMP)
use_amp = scaler is not None
# switch to train mode
# 将模型切换到训练模式
model.train()
end = time.time()
for i, item in enumerate(train_loader):
images = item["image"]
target = item["label"]
# measure data loading time
# 记录数据加载时间
data_time.update(time.time() - end)
# 如果有多个 GPU,则清空 GPU 缓存
if n_gpus:
torch.cuda.empty_cache()
if print_gpu_memory:
for i_gpu in range(n_gpus):
print("Memory allocated (device {}):".format(i_gpu),
human_readable_size(torch.cuda.memory_allocated(i_gpu)))
print("Max memory allocated (device {}):".format(i_gpu),
human_readable_size(torch.cuda.max_memory_allocated(i_gpu)))
print("Memory cached (device {}):".format(i_gpu),
human_readable_size(torch.cuda.memory_cached(i_gpu)))
print("Max memory cached (device {}):".format(i_gpu),
human_readable_size(torch.cuda.max_memory_cached(i_gpu)))
# 优化器梯度清零
optimizer.zero_grad()
# 计算批次的损失值和批次大小
loss, batch_size = batch_loss(model, images, target, criterion, n_gpus=n_gpus, use_amp=use_amp)
# measure accuracy and record loss
# 记录损失值
losses.update(loss.item(), batch_size)
# 如果启用混合精度训练,则使用 scaler 进行梯度缩放和反向传播
if scaler:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
# compute gradient and do step
# 否则,直接进行反向传播和优化器步骤
loss.backward()
optimizer.step()
# 释放 loss 变量的内存
del loss
# measure elapsed time
# 记录批次时间
batch_time.update(time.time() - end)
end = time.time()
# 每达到一定的打印频率,显示进度信息
if i % print_frequency == 0:
progress.display(i+1)
# 如果达到指定的 samples_per_epoch 数量,则提前结束 epoch
if samples_per_epoch and (i + 1) * batch_size >= samples_per_epoch:
break
# 返回平均损失
return losses.avg
源码较为复杂,目前仅看懂了一部分,随时补充,欢迎大家一起交流,因为本人为初学者,所以有写的不准确或者不对的地方欢迎指正,我会在第一时间进行修改。
后续将在此基础上自己制作数据集并尝试修改代码训练。
第一篇bolg,希望大家觉得不错可以点赞支持谢谢!
458

被折叠的 条评论
为什么被折叠?



