NeWCRFs

  1. 训练命令:
python newcrfs/train.py configs/arguments_train_nyu.txt
  1. main
if __name__ == '__main__':
    # Script entry point: run main() only when this file is executed directly
    # (e.g. `python newcrfs/train.py configs/arguments_train_nyu.txt`), not on import.
    main()
  1. main()
def main():
    """Prepare the run directory, check the GPU setup, and launch training.

    Reads the global ``args`` namespace, whose config file path is ``sys.argv[1]``
    (e.g. ``configs/arguments_train_nyu.txt``). Before training it snapshots the
    config, ``newcrfs/train.py`` and the ``newcrfs/networks`` / ``newcrfs/dataloaders``
    sources into ``<log_directory>/<model_name>`` so the run can be reproduced.

    Returns:
        -1 when more than one GPU is visible without ``--multiprocessing_distributed``;
        otherwise None once training has finished.
    """
    import glob
    import shutil

    def _snapshot(src, dst_dir):
        # Best-effort copy of `src` into `dst_dir`. The former shell `cp` had its
        # exit status ignored, so a failure is reported rather than aborting the run.
        try:
            shutil.copy(src, dst_dir)
        except OSError as err:
            print("== Could not copy '{}' into '{}': {}".format(src, dst_dir, err))

    if args.mode != 'train':  # only 'train' mode is handled here; exit otherwise
        ...  # exit statements elided in this excerpt

    # Run directory, e.g. ./models/newcrfs_nyu. makedirs also creates a missing
    # log_directory (e.g. ./models) and tolerates an existing run directory; the
    # former `mkdir` shell call failed when the parent was missing, and the config
    # and train.py copies were then silently lost. Avoiding shell strings also keeps
    # paths containing spaces working.
    args_out_path = os.path.join(args.log_directory, args.model_name)
    os.makedirs(args_out_path, exist_ok=True)
    # Keep a copy of the config file, e.g. configs/arguments_train_nyu.txt.
    _snapshot(sys.argv[1], args_out_path)

    save_files = True
    if save_files:
        aux_out_path = args_out_path
        # Snapshot the training script and all network / dataloader sources,
        # e.g. into ./models/newcrfs_nyu/networks and ./models/newcrfs_nyu/dataloaders.
        _snapshot('newcrfs/train.py', aux_out_path)
        for sub_dir in ('networks', 'dataloaders'):
            save_path = os.path.join(aux_out_path, sub_dir)
            os.makedirs(save_path, exist_ok=True)
            for src in sorted(glob.glob(os.path.join('newcrfs', sub_dir, '*.py'))):
                _snapshot(src, save_path)

    torch.cuda.empty_cache()
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if ngpus_per_node > 1 and not args.multiprocessing_distributed:
        print("This machine has more than 1 gpu. Please specify --multiprocessing_distributed, or set \'CUDA_VISIBLE_DEVICES=0\'")
        return -1

    if args.do_online_eval:
        print("You have specified --do_online_eval.")
        # eval_freq is a step count (e.g. evaluate every 1000 steps).
        print("This will evaluate the model every eval_freq {} steps and save best models for individual eval metrics."
              .format(args.eval_freq))

    if args.multiprocessing_distributed:  # distributed: one worker process per GPU
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)
  1. main_worker(0, 1, args)
def main_worker(gpu, ngpus_per_node, args):
    """Build NeWCRFs, optionally resume from a checkpoint, and run the training loop.

    Called directly as ``main_worker(args.gpu, ngpus_per_node, args)`` on a single
    GPU (e.g. ``main_worker(0, 1, args)``), or once per GPU through ``mp.spawn``.

    Args:
        gpu: CUDA device index for this process, or None.
        ngpus_per_node: number of visible CUDA devices.
        args: parsed training configuration (mutated: ``args.gpu`` is set).

    Returns:
        -1 if the loss becomes NaN; otherwise None after all epochs complete.
    """
    args.gpu = gpu  # e.g. 0
    if args.gpu is not None:
        print("== Use GPU: {} for training".format(args.gpu))  # e.g. "== Use GPU: 0 for training"
    if args.distributed:  # False for single-GPU training
        ...  # distributed process-group setup elided in this excerpt

    # e.g. --encoder large07 --max_depth 10 --pretrain swin_large_patch4_window7_224_22k.pth
    model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=args.pretrain)
    model.train()

    # Count total and trainable parameters (elided in this excerpt).
    ...

    if args.distributed:  # False
        ...
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()
        print("== Model Initialized")

    global_step = 0
    # Best value seen so far per online-eval metric: the 6 lower-is-better metrics
    # (Abs Rel, ...) start at 1e3, the 3 higher-is-better ones (Delta1, ...) at 0.
    best_eval_measures_lower_better = torch.zeros(6).cpu() + 1e3
    best_eval_measures_higher_better = torch.zeros(3).cpu()
    best_eval_steps = np.zeros(9, dtype=np.int32)

    # Training parameters
    optimizer = torch.optim.Adam([{'params': model.module.parameters()}],
                                 lr=args.learning_rate)

    model_just_loaded = False
    if args.checkpoint_path != '':  # resume training after an interruption
        if os.path.isfile(args.checkpoint_path):
            print("== Loading checkpoint '{}'".format(args.checkpoint_path))
            if args.gpu is None:
                checkpoint = torch.load(args.checkpoint_path)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.checkpoint_path, map_location=loc)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if not args.retrain:
                try:
                    global_step = checkpoint['global_step']
                    best_eval_measures_higher_better = checkpoint['best_eval_measures_higher_better'].cpu()
                    best_eval_measures_lower_better = checkpoint['best_eval_measures_lower_better'].cpu()
                    best_eval_steps = checkpoint['best_eval_steps']
                except KeyError:
                    print("Could not load values for online evaluation")

            # .get(): a checkpoint without 'global_step' must not crash this message
            # (the retrain path above never reads that key).
            print("== Loaded checkpoint '{}' (global_step {})".format(args.checkpoint_path, checkpoint.get('global_step')))
            model_just_loaded = True
            # Free the checkpoint only when one was loaded; an unconditional `del`
            # raised UnboundLocalError when the path did not exist.
            del checkpoint
        else:
            print("== No checkpoint found at '{}'".format(args.checkpoint_path))

    cudnn.benchmark = True
    dataloader = NewDataLoader(args, 'train')
    dataloader_eval = NewDataLoader(args, 'online_eval')

    # Only the single process, or rank 0 of each node, writes logs and summaries.
    is_main_process = not args.multiprocessing_distributed or args.rank % ngpus_per_node == 0

    # Logging
    if is_main_process:
        # e.g. ./models/newcrfs_nyu/summaries
        writer = SummaryWriter(args.log_directory + '/' + args.model_name + '/summaries', flush_secs=30)
        if args.do_online_eval:
            if args.eval_summary_directory != '':
                eval_summary_path = os.path.join(args.eval_summary_directory, args.model_name)
            else:
                # e.g. ./models/newcrfs_nyu/eval
                eval_summary_path = os.path.join(args.log_directory, args.model_name, 'eval')
            eval_summary_writer = SummaryWriter(eval_summary_path, flush_secs=30)

    silog_criterion = silog_loss(variance_focus=args.variance_focus)  # e.g. 0.85

    start_time = time.time()
    duration = 0

    num_log_images = args.batch_size  # e.g. 8
    end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else 0.1 * args.learning_rate

    var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
    var_cnt = len(var_sum)
    var_sum = np.sum(var_sum)
    print("== Initial variables' sum: {:.3f}, avg: {:.3f}".format(var_sum, var_sum/var_cnt))

    steps_per_epoch = len(dataloader.data)  # number of batches per epoch
    num_total_steps = args.num_epochs * steps_per_epoch
    epoch = global_step // steps_per_epoch

    while epoch < args.num_epochs:
        for step, sample_batched in enumerate(dataloader.data):
            optimizer.zero_grad()
            before_op_time = time.time()

            image = torch.autograd.Variable(sample_batched['image'].cuda(args.gpu, non_blocking=True))
            depth_gt = torch.autograd.Variable(sample_batched['depth'].cuda(args.gpu, non_blocking=True))
            depth_est = model(image)

            # Valid ground-truth pixels: depth above 0.1 for NYU, above 1.0 otherwise.
            if args.dataset == 'nyu':
                mask = depth_gt > 0.1
            else:
                mask = depth_gt > 1.0

            loss = silog_criterion.forward(depth_est, depth_gt, mask.to(torch.bool))
            loss.backward()
            # Polynomial (power 0.9) decay from learning_rate to end_learning_rate;
            # the value is the same for every param group, so compute it once.
            current_lr = (args.learning_rate - end_learning_rate) * (1 - global_step / num_total_steps) ** 0.9 + end_learning_rate
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            optimizer.step()

            if is_main_process:
                print('[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'.format(epoch, step, steps_per_epoch, global_step, current_lr, loss))
                if np.isnan(loss.cpu().item()):
                    print('NaN in loss occurred. Aborting training.')
                    return -1

            duration += time.time() - before_op_time
            if global_step and global_step % args.log_freq == 0 and not model_just_loaded:
                var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
                var_cnt = len(var_sum)
                var_sum = np.sum(var_sum)
                examples_per_sec = args.batch_size / duration * args.log_freq
                duration = 0
                time_sofar = (time.time() - start_time) / 3600
                training_time_left = (num_total_steps / global_step - 1.0) * time_sofar
                if is_main_process:
                    print("{}".format(args.model_name))
                print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
                print(print_string.format(args.gpu, examples_per_sec, loss, var_sum.item(), var_sum.item()/var_cnt, time_sofar, training_time_left))

                if is_main_process:
                    writer.add_scalar('silog_loss', loss, global_step)
                    writer.add_scalar('learning_rate', current_lr, global_step)
                    writer.add_scalar('var average', var_sum.item()/var_cnt, global_step)
                    # Replace near-zero (invalid) depths so 1/depth stays finite for display.
                    depth_gt = torch.where(depth_gt < 1e-3, depth_gt * 0 + 1e3, depth_gt)
                    # A final batch may be smaller than batch_size; never index past it.
                    for i in range(min(num_log_images, image.shape[0])):
                        writer.add_image('depth_gt/image/{}'.format(i), normalize_result(1/depth_gt[i, :, :, :].data), global_step)
                        writer.add_image('depth_est/image/{}'.format(i), normalize_result(1/depth_est[i, :, :, :].data), global_step)
                        writer.add_image('image/image/{}'.format(i), inv_normalize(image[i, :, :, :]).data, global_step)
                    writer.flush()

            if args.do_online_eval and global_step and global_step % args.eval_freq == 0 and not model_just_loaded:
                time.sleep(0.1)
                model.eval()
                with torch.no_grad():
                    eval_measures = online_eval(model, dataloader_eval, gpu, ngpus_per_node, post_process=True)
                # NOTE(review): eval_summary_writer only exists on the main process;
                # this presumes online_eval returns None on other ranks — confirm.
                if eval_measures is not None:
                    for i in range(9):
                        eval_summary_writer.add_scalar(eval_metrics[i], eval_measures[i].cpu(), int(global_step))
                        measure = eval_measures[i]
                        is_best = False
                        if i < 6 and measure < best_eval_measures_lower_better[i]:
                            old_best = best_eval_measures_lower_better[i].item()
                            best_eval_measures_lower_better[i] = measure.item()
                            is_best = True
                        elif i >= 6 and measure > best_eval_measures_higher_better[i-6]:
                            old_best = best_eval_measures_higher_better[i-6].item()
                            best_eval_measures_higher_better[i-6] = measure.item()
                            is_best = True
                        if is_best:
                            # Delete the previous best checkpoint for this metric, then save the new one.
                            old_best_step = best_eval_steps[i]
                            old_best_name = '/model-{}-best_{}_{:.5f}'.format(old_best_step, eval_metrics[i], old_best)
                            model_path = args.log_directory + '/' + args.model_name + old_best_name
                            if os.path.exists(model_path):
                                command = 'rm {}'.format(model_path)
                                os.system(command)
                            best_eval_steps[i] = global_step
                            model_save_name = '/model-{}-best_{}_{:.5f}'.format(global_step, eval_metrics[i], measure)
                            print('New best for {}. Saving model: {}'.format(eval_metrics[i], model_save_name))
                            checkpoint = {'global_step': global_step,
                                          'model': model.state_dict(),
                                          'optimizer': optimizer.state_dict(),
                                          'best_eval_measures_higher_better': best_eval_measures_higher_better,
                                          'best_eval_measures_lower_better': best_eval_measures_lower_better,
                                          'best_eval_steps': best_eval_steps
                                          }
                            torch.save(checkpoint, args.log_directory + '/' + args.model_name + model_save_name)
                    eval_summary_writer.flush()
                model.train()
                block_print()
                enable_print()

            model_just_loaded = False
            global_step += 1

        epoch += 1

    if is_main_process:
        writer.close()
        if args.do_online_eval:
            eval_summary_writer.close()
  1. model = NewCRFDepth(version='large07', inv_depth=False, max_depth=10, pretrained='swin_large_patch4_window7_224_22k.pth')
class NewCRFDepth(nn.Module):
    """
    Depth network based on neural window FC-CRFs architecture.

    A Swin Transformer backbone yields four feature levels. A PSP decoder head
    processes them, and four NewCRF stages refine the result from the deepest level
    to the shallowest, each followed by a 2x PixelShuffle upsample. DispHead then
    produces an output that is scaled by ``max_depth``.

    Args:
        version: backbone family plus two-digit window size, e.g. 'large07'
            (family is 'base', 'large' or 'tiny').
        inv_depth: stored as ``self.inv_depth``; not read in this excerpt.
        pretrained: pretrained backbone weights passed to the backbone's
            ``init_weights``, e.g. 'swin_large_patch4_window7_224_22k.pth'.
        frozen_stages: forwarded to the Swin backbone config.
        min_depth: stored as ``self.min_depth``.
        max_depth: multiplier applied to DispHead's output in ``forward``.

    Raises:
        ValueError: if ``version`` does not start with a known backbone family.
    """
    def __init__(self, version='large07', inv_depth=False, pretrained='swin_large_patch4_window7_224_22k.pth',
                 frozen_stages=-1, min_depth=0.1, max_depth=10.0, **kwargs):
        super().__init__()

        # Honour the constructor arguments (the caller passes args.max_depth)
        # instead of hard-coding them.
        self.inv_depth = inv_depth
        self.with_auxiliary_head = False
        self.with_neck = False
        norm_cfg = dict(type='BN', requires_grad=True)
        window_size = int(version[-2:])  # e.g. 'large07' -> 7

        if version[:-2] == 'base':
            ...  # base settings elided in this excerpt
        elif version[:-2] == 'large':
            embed_dim = 192
            depths = [2, 2, 18, 2]
            num_heads = [6, 12, 24, 48]
            in_channels = [192, 384, 768, 1536]
        elif version[:-2] == 'tiny':
            ...  # tiny settings elided in this excerpt
        else:
            raise ValueError("Unknown NewCRFDepth version: {!r}".format(version))

        backbone_cfg = dict(
            embed_dim=embed_dim,
            depths=depths,
            num_heads=num_heads,
            window_size=window_size,
            ape=False,
            drop_path_rate=0.3,
            patch_norm=True,
            use_checkpoint=False,
            frozen_stages=frozen_stages
        )

        embed_dim = 512  # decoder (PSP) channel width
        decoder_cfg = dict(
            in_channels=in_channels,
            in_index=[0, 1, 2, 3],
            pool_scales=(1, 2, 3, 6),
            channels=embed_dim,
            dropout_ratio=0.0,
            num_classes=32,
            norm_cfg=norm_cfg,
            align_corners=False
        )

        self.backbone = SwinTransformer(**backbone_cfg)
        win = 7  # CRF attention window size
        crf_dims = [128, 256, 512, 1024]
        v_dims = [64, 128, 256, 512]
        # Derive CRF input widths from the backbone's in_channels rather than the
        # 'large' literals, so every backbone family gets matching layers.
        # (For 'large' these equal 1536/768/384/192, as before.)
        self.crf3 = NewCRF(input_dim=in_channels[3], embed_dim=crf_dims[3], window_size=win, v_dim=v_dims[3], num_heads=32)
        self.crf2 = NewCRF(input_dim=in_channels[2], embed_dim=crf_dims[2], window_size=win, v_dim=v_dims[2], num_heads=16)
        self.crf1 = NewCRF(input_dim=in_channels[1], embed_dim=crf_dims[1], window_size=win, v_dim=v_dims[1], num_heads=8)
        self.crf0 = NewCRF(input_dim=in_channels[0], embed_dim=crf_dims[0], window_size=win, v_dim=v_dims[0], num_heads=4)

        self.decoder = PSP(**decoder_cfg)
        self.disp_head1 = DispHead(input_dim=crf_dims[0])

        self.up_mode = 'bilinear'
        self.min_depth = min_depth
        self.max_depth = max_depth

        self.init_weights(pretrained=pretrained)

    def init_weights(self, pretrained=None):
        """Load pretrained backbone weights and initialise the decoder head."""
        print(f'== Load encoder backbone from: {pretrained}')
        self.backbone.init_weights(pretrained=pretrained)
        self.decoder.init_weights()

    def forward(self, imgs):
        """Predict depth for a batch of images and return it.

        Shape notes below are the excerpt's own annotations for the large07 model
        (apparently for a 224x224 input) — TODO confirm for other sizes.
        """
        feats = self.backbone(imgs)  # e.g. [56x56x192, 28x28x384, 14x14x768, 7x7x1536]
        ppm_out = self.decoder(feats)  # e.g. 7x7x512

        # Each PixelShuffle(2) doubles H and W and divides channels by 4.
        e3 = self.crf3(feats[3], ppm_out)
        e3 = nn.PixelShuffle(2)(e3)
        e2 = self.crf2(feats[2], e3)
        e2 = nn.PixelShuffle(2)(e2)
        e1 = self.crf1(feats[1], e2)
        e1 = nn.PixelShuffle(2)(e1)
        e0 = self.crf0(feats[0], e1)

        d1 = self.disp_head1(e0, 4)  # 4: presumably the output upsampling factor — confirm
        depth = d1 * self.max_depth

        return depth
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值