【深度学习模型代码结构】第1篇——main.py(上)

        #声明,由于本人为新手小白,因此本人的代码是在其他大佬的代码基础上进行修改并自己使用的,并不是完全自己编写,因此如果不慎涉及侵权请联系我进行删除谢谢。然后如果觉得有帮助希望大家点赞鼓励谢谢。       

        本篇博客用于记录解释我的main.py函数,该函数为平时训练时所需要运行的函数(当然其他模块均必备后),在terminal中输入python main.py即可开始网络训练。话不多说,开始代码部分:

       首先是函数库的声明,我喜欢将其分为两个部分,第一个部分是python中大佬们已经打包好的常用库,第二部分是我自己写的其他函数,在此不做过多展开,在后面的博客中将会对每个代码进行详细说明,确保整个代码是可运行的。

from medicaltorch import datasets as mt_datasets
import math
import numpy as np
import os
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau #当某指标不再变化(下降或升高),调整学习率
from torch.nn import BCELoss #Cross Entropy
from torch.nn.parallel import DataParallel #并行计算
from torch.nn.utils import clip_grad_norm_#梯度剪裁


from dataload.dataset_u3d import MRIDataset
from main import runs
from model.Unet3D import UNet3D
from utils.cumulative_average import CumulativeAverager
from utils import parse
from utils import optim
from utils import save

 首先是运行部分,其中root为数据保存的地址:

if __name__ == '__main__':
    root=r'*该main.py文件所在的路径*\\datasets'   #服务器数据地址

    main()

下面来说明一下main函数,首先将其全部贴出:

def main():
    log_str = ''
    # 初始参数设置
    args = parse.parse_opt()
    args.cuda = torch.cuda.is_available()
    # 加载数据集
    if not args.demo:
        primary_dataset = get_model(args, mode='train')
        val_dataset =  get_model(args, mode='val')
        primary_data_loader = DataLoader(primary_dataset, args.batch_size, shuffle=True,num_workers=0, collate_fn=mt_datasets.mt_collate)
        val_data_loader = DataLoader(val_dataset, args.batch_size, shuffle=True,num_workers=0, collate_fn=mt_datasets.mt_collate)
        # save.add_to_log(log_str,"training on %d samples"%len(primary_dataset))
    else:
        test_dataset = get_model(args, mode='test')
        test_data_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)

    # 初始化模型
    device = torch.device("cuda:0")
    net = UNet3D().to(device)
    net.train()

    net = DataParallel(net,device_ids=[0])
    bce_criterion = BCELoss()

    optimizer = optim.get_optimizer(net, args.optimizer, args.lr)
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2)#当某指标不再变化调整学习率

    losstype = [0,1,2,3,4,5], 'bce'

    # # 多GPU并行
    # gpu_ids = range(args.ngpu)
    # model = nn.parallel.DataParallel(model, device_ids=gpu_ids)

    if not args.demo:

        avg_tool = CumulativeAverager()

        vloss, is_best = torch.tensor(float(np.inf)), None
        # 检查是否从头开始训练
        if args.load_from is not None:
            if os.path.isfile(args.load_from):
                log_str = save.add_to_log(log_str, "=> loading checkpoint '{}'".format(args.load_from))
                checkpoint = torch.load(args.load_from)
                start = checkpoint['epoch']
                vloss = checkpoint['best_val_loss']
                net.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                args.lr = checkpoint['learning_rate']
                log_str = save.add_to_log(log_str, "=> loaded checkpoint '{}' (epoch {})"
                    .format(args.load_from, checkpoint['epoch']))
            else:
                log_str = save.add_to_log(log_str, "=> no checkpoint found at '{}'".format(args.load_from))
        else:
            start = 0
            
        # 训练
        train_loss_path = os.path.join('runs', 'train_loss.txt')
        val_loss_path = os.path.join('runs', 'val_loss.txt')
        os.makedirs('runs', exist_ok=True)
        for epoch in range(start, args.epochs):

            train_loss = runs.train(net, args, primary_data_loader, optimizer, bce_criterion, avg_tool, epoch, train_loss_path, losstype=losstype)
            val_loss = runs.validate(net, args, val_data_loader, bce_criterion, log_str, losstype=losstype).cpu()
            scheduler.step(val_loss)

            # 保存训练和验证损失
            with open(train_loss_path, 'a') as f_train_loss:
                f_train_loss.write(f"Epoch {epoch+1}, Train Loss={train_loss:.6f}\n")
            with open(val_loss_path, 'a') as f_val_loss:
                f_val_loss.write(f"Epoch {epoch+1}, val Loss={val_loss:.6f}\n")

            if vloss > val_loss:
                vloss = val_loss
                best_model = {
                    'epoch': epoch,
                    'state_dict': net.state_dict(),
                    'best_val_loss': vloss,
                    'optimizer': optimizer.state_dict(),
                    'learning_rate': args.lr
                }
                torch.save(best_model, args.best_path)
                print(f'Best model saved. It is the {epoch+1} epoch.')

            for param_group in optimizer.param_groups:
                lr = param_group['lr']
                
            if epoch == args.epochs-1:
                final_model = {
                    'epoch': epoch,
                    'state_dict': net.state_dict(),
                    'best_val_loss': val_loss,
                    'optimizer': optimizer.state_dict(),
                    'learning_rate': args.lr
                }

        torch.save(best_model, args.best_path)
        torch.save(final_model, args.final_path)

    else:
        runs.test(net, args.test_path, test_data_loader, args.test_save)

接下来逐部分进行解释,首先是初始参数设置:

    # 初始参数设置
    args = parse.parse_opt()
    args.cuda = torch.cuda.is_available()

初始参数设置通过utils模块下的parse.py中的parse_opt()函数实现,该模块参考yolo v5的数据加载形式,通过该形式,可以对模型训练过程中使用的参数进行统一的调整,使用起来较为方便。

其中parse_opt()包含的参数如下,我将其分为基础训练参数(如epoch、batchsize、learningrate等);高级训练参数(如是否中断训练,是否继续训练);储存参数(保存的路径);test参数(测试时的读取与保存路径等)...待补充

import argparse

def parse_opt(known=False):
    parser = argparse.ArgumentParser()
    # 基础训练参数
    parser.add_argument('--batch_size', type=int, help='Number of 3D voxel batches',default=16)
    parser.add_argument('--lr', type=float, help='Initial Learning rate', default=0.001)
    parser.add_argument('--lr_decay', type=float, help='Learning Rate Decay', default=0.1)
    parser.add_argument('--optimizer', help='sgd, adam', default='adam')
    parser.add_argument('--epochs', help='Total number of epochs to train on data', default=100, type=int)

    # 高级训练参数
    parser.add_argument('--iters', help='Number of training batches per epoch', default=None, type=int)
    parser.add_argument('--aug', action='store_true', help='Flag to decide about input augmentations',default=False)
    parser.add_argument('--demo', action='store_true', help='Flag to indicate testing',default=False)
    parser.add_argument('--load_from', help='Path to checkpoint dict', default=None)
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',help='manual epoch number (useful on restarts)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',help='evaluate model on validation set')
    parser.add_argument('-i', '--inference', default='', type=str, metavar='PATH',help='run inference on data set and save results')

    # 储存参数
    parser.add_argument('--best_path', default='runs/best_model.pth', type=str)
    parser.add_argument('--final_path', default='runs/final_model.pth', type=str)

    # test参数
    parser.add_argument('--test_path', default='runs\\best_model.pth', type=str)
    parser.add_argument('--test_data', default='data/test', type=str)
    parser.add_argument('--test_save', default='runs/test_results_128_batch16', type=str)
    # parser.add_argument('--test_batch_size', default=1, type=int)

    return parser.parse_known_args()[0] if known else parser.parse_args()

初始参数设置好后,第一步要进行的操作是加载数据集:

    # 加载数据集
    if not args.demo:
        primary_dataset = get_model(args, mode='train')
        val_dataset =  get_model(args, mode='val')
        primary_data_loader = DataLoader(primary_dataset, args.batch_size, shuffle=True,num_workers=0, collate_fn=mt_datasets.mt_collate)
        val_data_loader = DataLoader(val_dataset, args.batch_size, shuffle=True,num_workers=0, collate_fn=mt_datasets.mt_collate)
        # save.add_to_log(log_str,"training on %d samples"%len(primary_dataset))
    else:
        test_dataset = get_model(args, mode='test')
        test_data_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)

该部分代码的功能是加载数据集,其中if是用于判断当前为训练还是测试,根据parse中的demo参数可知,当args.demo为False时,为训练模式,因此加载的数据集为训练集和验证集,当demo为True时,此时为测试模式,加载的数据集为测试集。

如果想进入测试模式,则可以在parse中修改demo为True,或者在terminal中运行时输入python main.py --demo

该部分代码涉及到了get_model()和dataloader()两个函数,其中dataloader()是torch库中的数据集加载器,我个人的理解该函数的返回值是一组batch中用于训练的数据的索引,通过该索引即可在每个循环中调用数据集中数据用于训练。

逐个解释参数的含义(来自chatgpt):

  1. primary_dataset: 这是用于训练的主要数据集。在这里,它被传递给DataLoader以便加载其中的数据。

  2. args.batch_size: 这是一个参数,表示每个批次(batch)中包含的样本数。批量处理可以帮助提高训练效率并且可以利用并行计算。通常,它是一个超参数,可以在训练过程中进行调整以优化性能。

  3. shuffle=True: 这个参数表示是否在每个epoch之前对数据进行洗牌(随机重排)。洗牌可以帮助模型更好地学习,因为它可以减少每个epoch中样本的顺序相关性,从而增加模型的泛化能力。

  4. num_workers=0: 这是指定用于数据加载的子进程数量。当num_workers大于0时,数据加载器会使用多个子进程来预先加载数据,从而加速训练过程。但是在某些环境下(例如Windows),使用多个子进程可能会导致问题,因此将其设为0意味着只使用主进程加载数据。

  5. collate_fn=mt_datasets.mt_collate: 这是一个函数,用于自定义数据加载器如何将样本组合成批次。通常,它用于处理样本长度不一致的情况,比如在自然语言处理任务中,不同句子的长度可能不同。在这里,mt_collate是一个自定义的函数,可能会执行填充(padding)或截断(truncation)等操作,以确保每个批次中的样本具有相同的长度。

get_model()函数是一个自己编写的函数,其功能为加载数据,并提供一些加载数据的参数

'''获取数据集'''
#如果显存溢出,就要resize一下输入的图片尺寸
def get_model(args, mode, flag_3d = True, channel_size_3d = 32, mri_slice_dim = 128):

    assert math.log(mri_slice_dim, 2).is_integer() # Image dims must be powers of 2

    #是否数据增广
    if mode == 'train' and args.aug:
        aug = True
    else:
        aug = False

    t1_lgg = MRIDataset(root=root,  mode=mode, channel_size_3d=channel_size_3d, flag_3d=flag_3d, mri_slice_dim=mri_slice_dim, aug=aug)
    dataset = t1_lgg

    return dataset

在该段代码中,get_model定义了加载数据的维度,由于我训练过程中使用的是三维数据,因此其目前代码中的维度为128*128*32,如果显存大小存在差异,以及batchsize需要调整,可以同时修改数据的尺寸以满足要求。

assert为确保输入的数据mri_slice_dim为2的幂次方,以确保能正确输入后续的模型中;

aug为判断是否需要做数据增强,训练时进行数据增强,验证和测试时不做数据增强;

MRIDataset为自己的数据集加载类,将在后续进行详细讲解。

加载完数据后,接下来时初始化模型操作:

    # 初始化模型
    device = torch.device("cuda:0")
    net = UNet3D().to(device)
    net.train()

    net = DataParallel(net,device_ids=[0])
    bce_criterion = BCELoss()

    optimizer = optim.get_optimizer(net, args.optimizer, args.lr)
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2)#当某指标不再变化调整学习率

    losstype = [0,1,2,3,4,5], 'bce'

device = torch.device("cuda:0"):选择训练的显卡,如果是笔记本电脑即0即可,如果使用服务器存在多个显卡,可以指定;

net = UNet3D().to(device):加载网络,我目前使用的是UNet3D,将在后面进行详细介绍;

net = DataParallel(net,device_ids=[0]):使用 DataParallel 将模型包装在多个 GPU 上,以实现并行化训练。device_ids=[0] 指定了使用的 GPU 设备的编号。

bce_criterion = BCELoss():创建一个二分类交叉熵损失函数对象,用于计算模型预测与真实标签之间的损失。但是在后面目前用的dice_loss,并未使用该loss。

optimizer = optim.get_optimizer(net, args.optimizer, args.lr):设置优化器,用于更新模型的参数。args.optimizer 和 args.lr 是从命令行参数中获取的优化器类型和学习率。其中get_optimizer函数将下面介绍。

scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2):使用ReduceLROnPlateau 调度器对学习率进行动态调整。它会在损失值不再减少时,降低学习率,有助于在训练后期更细致地调整学习率以获得更好的性能。'min' 表示在监视的指标(这里是损失)最小化时降低学习率。

losstype = [0,1,2,3,4,5], 'bce':定义损失类型。该行代码后续暂未用到。

from torch.optim import SGD, Adam

def get_optimizer(net, st, lr, momentum=0.9):
    if st == 'sgd':
        return SGD(net.parameters(), lr = lr, momentum=momentum)
    elif st == 'adam':
        return Adam(net.parameters(), lr = lr)

get_optimizer函数,根据parse参数来确定模型使用什么优化器,目前args.optimizer = adam,即使用的是adam优化器。

初步先更新这么多,后续代码将开始训练过程,将在下一篇博客中进行更新,整体代码由于参考了其他大佬的代码,因此其中还存在一些冗余的代码,由于现在代码可以运行,因此还未完全清除,后续将不断更新。

  • 29
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值