windows下调整STDC的训练代码为单机单卡训练

问题描述:

STDC 原本是在 Linux 系统下使用多卡分布式训练的,这里记录一下在 Windows 下将其改为单机单卡训练的方法。

解决方案:

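思路就是把源代码中与分布式训练有关的代码全部删除。主要包括以下几处:
- dist.init_process_group()
- dist.get_rank()。保存模型处的 if dist.get_rank()==0 判断也要一并去掉,否则会因为没有初始化进程组而报错。
- sampler = torch.utils.data.distributed.DistributedSampler(ds)
- net = nn.parallel.DistributedDataParallel(...)

删除 DistributedSampler 之后,训练用的 DataLoader 需要设置 shuffle=True。DistributedSampler 默认会打乱数据,去掉它以后不设置的话,训练数据就不再打乱。
然后再调整一些参数:创建优化器 Optimizer 时把 model = net.module 改为 model = net;参数 local_rank 保留(默认 0,单卡时用作 GPU 编号):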

# Keep this option: in single-GPU mode it only selects the CUDA device index
# passed to torch.cuda.set_device() in train() (default 0).
parse.add_argument(
            '--local_rank',
            dest = 'local_rank',
            type = int,
            default = 0,
            )

此外,如果出现报错:RuntimeError: Some elements marked as dirty during the forward method were not returned as output. The inputs that are modified inplace must all be outputs of the Function。就把model_stages里的BatchNorm2d从InPlaceABNSync全部换成torch自带的nn.BatchNorm2d。

修改后完整的 train.py 如下:

from logger import setup_logger
from models.model_stages import BiSeNet
from cityscapes import CityScapes
from loss.loss import OhemCELoss
from loss.detail_loss import DetailAggregateLoss
from evaluation import MscEvalV0
from optimizer_loss import Optimizer

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.distributed as dist

import os
import os.path as osp
import logging
import time
import datetime
import argparse
# Root logger (getLogger() with no name). Handlers are presumably attached by
# setup_logger(args.respath) inside train() — confirm in logger.py.
logger = logging.getLogger()
def str2bool(v):
    """Convert a command-line token to a bool; intended as an argparse ``type=``.

    Accepts, case-insensitively, ``yes/true/t/y/1`` (-> True) and
    ``no/false/f/n/0`` (-> False). A value that is already a ``bool`` is
    returned unchanged, so the converter is safe on non-string defaults too.

    Raises:
        argparse.ArgumentTypeError: for any other value, which argparse turns
        into a clean usage error instead of a traceback.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')


def parse_args():
    """Parse the training script's command-line options from ``sys.argv``.

    Each option is exposed as ``--<name>`` and stored under ``dest=<name>``.
    The table below gives its converter and default, in registration order.
    The ``use_*`` switches are converted with ``str2bool``.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    option_table = (
        # name               converter  default
        ('local_rank',       int,       0),
        ('n_workers_train',  int,       4),
        ('n_workers_val',    int,       0),
        ('n_img_per_gpu',    int,       2),
        ('max_iter',         int,       40000),
        ('save_iter_sep',    int,       1000),
        ('warmup_steps',     int,       1000),
        ('mode',             str,       'train'),
        ('ckpt',             str,       None),
        ('respath',          str,       None),
        ('backbone',         str,       'CatNetSmall'),
        ('pretrain_path',    str,       ''),
        ('use_conv_last',    str2bool,  False),
        ('use_boundary_2',   str2bool,  False),
        ('use_boundary_4',   str2bool,  False),
        ('use_boundary_8',   str2bool,  False),
        ('use_boundary_16',  str2bool,  False),
    )
    parser = argparse.ArgumentParser()
    for name, converter, default in option_table:
        parser.add_argument('--' + name, dest=name, type=converter, default=default)
    return parser.parse_args()


# (use_boundary_2, use_boundary_4, use_boundary_8) combinations the training
# loop knows how to unpack: all on, 4+8, 8 only, or none.
_SUPPORTED_BOUNDARY_FLAGS = frozenset({
    (True, True, True),
    (False, True, True),
    (False, False, True),
    (False, False, False),
})


def _split_outputs(outputs, use_boundary_2, use_boundary_4, use_boundary_8):
    """Split the tuple returned by ``net(im)`` into heads and detail maps.

    The first three entries are the segmentation outputs (out, out16, out32).
    They are followed by one detail map per enabled boundary flag, in the
    order 2, 4, 8.

    Returns:
        (out, out16, out32, details) where ``details`` is a list of the
        detail maps in that order (empty when no boundary flag is set).

    Raises:
        RuntimeError: if the number of outputs does not match the flags.
    """
    n_details = sum(map(bool, (use_boundary_2, use_boundary_4, use_boundary_8)))
    if len(outputs) != 3 + n_details:
        raise RuntimeError(
            'model returned {} outputs, expected {} for use_boundary_2/4/8={}/{}/{}'.format(
                len(outputs), 3 + n_details,
                use_boundary_2, use_boundary_4, use_boundary_8))
    out, out16, out32 = outputs[:3]
    return out, out16, out32, list(outputs[3:])


def _state_dict(net):
    """Return ``net``'s state dict, unwrapping a ``.module`` wrapper if present."""
    return net.module.state_dict() if hasattr(net, 'module') else net.state_dict()


def train():
    """Train the STDC BiSeNet model on Cityscapes with a single process / GPU.

    This is the single-GPU port of the distributed script. No process group
    is initialised, so nothing here calls ``torch.distributed``: e.g.
    ``dist.get_rank()`` would raise at the first checkpoint save.

    Side effects:
        - creates ``<respath>/pths`` and configures logging into ``respath``;
        - every ``save_iter_sep`` iterations, evaluates on the 'val' split and
          writes a per-evaluation checkpoint plus best-mIOU50/75 checkpoints;
        - writes ``model_final.pth`` at the end.

    Raises:
        SystemExit: if ``--respath`` is missing or the boundary flags are a
        combination the training loop cannot unpack.
    """
    args = parse_args()
    if args.respath is None:
        raise SystemExit('--respath is required (checkpoints and logs are written there)')

    n_classes = 19
    n_img_per_gpu = args.n_img_per_gpu
    n_workers_train = args.n_workers_train
    n_workers_val = args.n_workers_val
    use_boundary_16 = args.use_boundary_16
    use_boundary_8 = args.use_boundary_8
    use_boundary_4 = args.use_boundary_4
    use_boundary_2 = args.use_boundary_2
    mode = args.mode

    # Fail before touching the filesystem/GPU. An unsupported combination
    # previously left `out` unbound and crashed on the first iteration.
    boundary_flags = (bool(use_boundary_2), bool(use_boundary_4), bool(use_boundary_8))
    if boundary_flags not in _SUPPORTED_BOUNDARY_FLAGS:
        raise SystemExit(
            'unsupported use_boundary_2/4/8 combination {}; supported: all on, '
            '4+8, 8 only, or none'.format(boundary_flags))

    save_pth_path = osp.join(args.respath, 'pths')
    dspth = './data'
    os.makedirs(save_pth_path, exist_ok=True)

    # Single process: select the GPU directly instead of init_process_group().
    torch.cuda.set_device(args.local_rank)
    setup_logger(args.respath)

    ## dataset
    # The original STDC setting was [1024, 512]. This small crop was chosen by
    # the port (presumably to fit a small GPU); restore it for real training.
    cropsize = [64, 64]
    randomscale = (0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5)

    logger.info('n_workers_train: {}'.format(n_workers_train))
    logger.info('n_workers_val: {}'.format(n_workers_val))
    logger.info('use_boundary_2: {}'.format(use_boundary_2))
    logger.info('use_boundary_4: {}'.format(use_boundary_4))
    logger.info('use_boundary_8: {}'.format(use_boundary_8))
    logger.info('use_boundary_16: {}'.format(use_boundary_16))
    logger.info('mode: {}'.format(args.mode))

    ds = CityScapes(dspth, cropsize=cropsize, mode=mode, randomscale=randomscale)
    # The removed DistributedSampler shuffled by default. Without it the
    # loader must shuffle itself, or every epoch sees the same order.
    dl = DataLoader(ds,
                    batch_size=n_img_per_gpu,
                    shuffle=True,
                    num_workers=n_workers_train,
                    pin_memory=False,
                    drop_last=True)
    dsval = CityScapes(dspth, mode='val', randomscale=randomscale)
    dlval = DataLoader(dsval,
                       batch_size=2,
                       shuffle=False,
                       num_workers=n_workers_val,
                       drop_last=False)

    ## model
    ignore_idx = 255
    net = BiSeNet(backbone=args.backbone, n_classes=n_classes, pretrain_model=args.pretrain_path,
                  use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
                  use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
                  use_conv_last=args.use_conv_last)
    if args.ckpt is not None:
        net.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
    net.cuda()
    net.train()

    ## losses: OHEM cross-entropy per head plus the detail (boundary) loss
    score_thres = 0.7
    n_min = n_img_per_gpu * cropsize[0] * cropsize[1] // 16
    criteria_p = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    criteria_16 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    criteria_32 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    boundary_loss_func = DetailAggregateLoss()

    ## optimizer
    maxmIOU50 = 0.
    maxmIOU75 = 0.
    momentum = 0.9
    weight_decay = 5e-4
    lr_start = 1e-2
    max_iter = args.max_iter
    save_iter_sep = args.save_iter_sep
    power = 0.9
    warmup_steps = args.warmup_steps
    warmup_start_lr = 1e-5

    print('max_iter: ', max_iter)
    print('save_iter_sep: ', save_iter_sep)
    print('warmup_steps: ', warmup_steps)
    optim = Optimizer(
            model=net,  # not net.module: there is no DDP wrapper any more
            loss=boundary_loss_func,
            lr0=lr_start,
            momentum=momentum,
            wd=weight_decay,
            warmup_steps=warmup_steps,
            warmup_start_lr=warmup_start_lr,
            max_iter=max_iter,
            power=power)

    ## train loop
    msg_iter = 50
    loss_hist = []
    bce_hist = []
    dice_hist = []
    st = glob_st = time.time()
    diter = iter(dl)
    epoch = 0
    for it in range(max_iter):
        batch = next(diter, None)
        if batch is None or batch[0].size(0) != n_img_per_gpu:
            # Loader exhausted (or a short batch): start a new epoch.
            epoch += 1
            diter = iter(dl)
            batch = next(diter)
        im, lb = batch
        im = im.cuda()
        lb = torch.squeeze(lb.cuda(), 1)

        optim.zero_grad()

        out, out16, out32, details = _split_outputs(
            net(im), use_boundary_2, use_boundary_4, use_boundary_8)

        lossp = criteria_p(out, lb)
        loss2 = criteria_16(out16, lb)
        loss3 = criteria_32(out32, lb)

        boundery_bce_loss = 0.
        boundery_dice_loss = 0.
        for detail in details:
            bce, dice = boundary_loss_func(detail, lb)
            boundery_bce_loss += bce
            boundery_dice_loss += dice

        loss = lossp + loss2 + loss3 + boundery_bce_loss + boundery_dice_loss
        loss.backward()
        optim.step()

        loss_hist.append(loss.item())
        # float() instead of .item(): with no boundary heads enabled these
        # sums are still the plain Python float 0.
        bce_hist.append(float(boundery_bce_loss))
        dice_hist.append(float(boundery_dice_loss))

        ## print training log message
        if (it + 1) % msg_iter == 0:
            ed = time.time()
            t_intv, glob_t_intv = ed - st, ed - glob_st
            done = it + 1
            eta = int((max_iter - done) * (glob_t_intv / done))
            eta = str(datetime.timedelta(seconds=eta))
            msg = ', '.join([
                'it: {it}/{max_it}',
                'lr: {lr:4f}',
                'loss: {loss:.4f}',
                'boundery_bce_loss: {boundery_bce_loss:.4f}',
                'boundery_dice_loss: {boundery_dice_loss:.4f}',
                'eta: {eta}',
                'time: {time:.4f}',
            ]).format(
                it=done,
                max_it=max_iter,
                lr=optim.lr,
                loss=sum(loss_hist) / len(loss_hist),
                boundery_bce_loss=sum(bce_hist) / len(bce_hist),
                boundery_dice_loss=sum(dice_hist) / len(dice_hist),
                time=t_intv,
                eta=eta
            )
            logger.info(msg)
            loss_hist = []
            bce_hist = []
            dice_hist = []
            st = ed

        ## periodic evaluation + checkpointing
        if (it + 1) % save_iter_sep == 0:
            logger.info('evaluating the model ...')
            logger.info('setup and restore model')
            net.eval()

            logger.info('compute the mIOU')
            with torch.no_grad():
                single_scale1 = MscEvalV0()
                mIOU50 = single_scale1(net, dlval, n_classes)

                single_scale2 = MscEvalV0(scale=0.75)
                mIOU75 = single_scale2(net, dlval, n_classes)

            # Single process: always save (no rank-0 check, there is no
            # process group to query).
            save_pth = osp.join(save_pth_path, 'model_iter{}_mIOU50_{}_mIOU75_{}.pth'
                                .format(it + 1, str(round(mIOU50, 4)), str(round(mIOU75, 4))))
            torch.save(_state_dict(net), save_pth)
            logger.info('training iteration {}, model saved to: {}'.format(it + 1, save_pth))

            if mIOU50 > maxmIOU50:
                maxmIOU50 = mIOU50
                save_pth = osp.join(save_pth_path, 'model_maxmIOU50.pth')
                torch.save(_state_dict(net), save_pth)
                logger.info('max mIOU model saved to: {}'.format(save_pth))

            if mIOU75 > maxmIOU75:
                maxmIOU75 = mIOU75
                save_pth = osp.join(save_pth_path, 'model_maxmIOU75.pth')
                torch.save(_state_dict(net), save_pth)
                logger.info('max mIOU model saved to: {}'.format(save_pth))

            logger.info('mIOU50 is: {}, mIOU75 is: {}'.format(mIOU50, mIOU75))
            logger.info('maxmIOU50 is: {}, maxmIOU75 is: {}.'.format(maxmIOU50, maxmIOU75))

            net.train()

    ## dump the final model
    save_pth = osp.join(save_pth_path, 'model_final.pth')
    net.cpu()
    torch.save(_state_dict(net), save_pth)
    logger.info('training done, model saved to: {}'.format(save_pth))
    print('epoch: ', epoch)

# Script entry point: run training when this file is executed directly.
if __name__ == "__main__":
    train()

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
STDC是一种基于C语言标准库的模块架构,它的核心思想是将代码按照功能模块划分成多个小模块,每个小模块负责处理一个特定的功能,这些小模块可以独立编译、测试和部署。STDC的模块架构可以让开发者更加高效地组织代码,提高代码的可维护性和可重用性。 下面是一个简单的STDC模块架构的示例: ``` stdc/ ├── include/ │ ├── module1.h │ ├── module2.h │ ├── ... ├── src/ │ ├── module1.c │ ├── module2.c │ ├── ... ├── Makefile ├── README.md ``` 在这个示例中,`include`目录包含了所有模块的头文件,`src`目录包含了所有模块的源代码文件。`Makefile`文件用于编译和链接所有模块的代码,生成可执行文件或库文件。 下面是一个简单的STDC模块架构的Python代码实现示例: ``` stdc/ ├── __init__.py ├── module1/ │ ├── __init__.py │ ├── module1.py ├── module2/ │ ├── __init__.py │ ├── module2.py ├── ... ``` 在这个示例中,STDC模块架构被用于Python代码的组织。每个功能模块都被封装在一个独立的目录中,目录下包含了一个`__init__.py`文件和一个或多个Python源代码文件。`__init__.py`文件用于定义模块的接口和导出需要暴露的函数和变量。 例如,`module1.py`文件可以定义一个名为`func1`的函数: ```python def func1(): print("This is module1's func1") ``` `__init__.py`文件可以将`func1`函数导出: ```python from .module1 import func1 ``` 这样,在其他Python模块中,可以使用以下语句导入`module1`模块并调用`func1`函数: ```python from stdc.module1 import func1 func1() ``` 总的来说,STDC模块架构的原理是将代码按照功能模块划分成多个小模块,每个小模块负责处理一个特定的功能,这些小模块可以独立编译、测试和部署。在Python中,可以通过创建独立的目录和`__init__.py`文件来实现STDC模块架构。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

想要躺平的一枚

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值