Depth Completion: PENet (ICRA 2021) — Paper Notes and Code Reading

This work proposes PENet, a two-branch network for image-guided depth completion. The network consists of a color-dominant branch and a depth-dominant branch, each exploiting the strengths of color and depth information respectively. A geometric convolutional layer encodes 3D geometric cues, and a CSPN++-based module performs efficient depth refinement. Experiments on the KITTI dataset show that this design effectively fuses color and depth information and improves the accuracy of the completed depth map.


Authors' code: https://github.com/JUGGHM/PENet_ICRA2021

Towards Precise and Efficient Image Guided Depth Completion

two-branch network

  • color-dominant (CD) branch

    color image + sparse depth map → predicted depth map

    Extracts color-dominant information for depth prediction; the predicted depth map is relatively reliable near object boundaries, but sensitive to color and texture variations

  • depth-dominant (DD) branch

    sparse depth map + the previously predicted depth map → dense depth map

    This branch is reliable overall, but the input sparse depth map carries severe noise near object boundaries

  • Overall, the strengths and weaknesses of the two branches are complementary, so learned confidence weights are used to fuse the two branches' results and fully exploit both the color and the depth information

  • propose a simple geometric convolutional layer to encode 3D geometric cues

    It simply augments a convolutional layer via concatenating a 3D position map to the layer’s input.

  • we additionally integrate a module based on CSPN++ to refine the depth map predicted by our backbone

    We design a dilated and accelerated implementation of CSPN++ to make the refinement more effective and efficient.

  • SOTA

Related Work
  • A. Depth Completion

    produce a dense depth map by completing a sparse depth map, with or without the guidance of a reference image

    challenges

    1. the input depth map is irregularly sparse and noisy;

    2. the color image and the depth map are two different modalities.

  • B. Geometric Encoding

    we propose a geometric convolutional layer to encode 3D geometric cues simply (inspired by CoordConv)

  • C. Spatial Propagation Networks

    spatial propagation network (SPN)

    convolutional spatial propagation network (CSPN)

    CSPN++ and NLSPN are proposed very recently.

    The former adaptively learns the convolutional kernel size and iteration number for propagation

    The latter learns deformable kernels.

    We adopt CSPN++ for our depth refinement, but introduce a dilation scheme to enlarge the neighborhoods and implement the propagation in a much more efficient way

METHODOLOGY

entire framework = two-branch backbone + a depth refinement module

  • The Two-branch Backbone

    color-dominant branch

    1. an aligned sparse depth map is also input to assist depth prediction

    2. The encoder contains one convolution layer and ten basic residual blocks

    3. The decoder has five deconvolution layers and one convolution layer

      (each convolution layer is followed by BN and ReLU)

    depth-dominant branch

    1. the decoder features of the color-dominant branch are concatenated with the corresponding encoder features in the depth-dominant branch (multi-stage fusion)

    Depth fusion

(figure: the predictions of the two branches are fused per pixel with learned confidence weights)
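
A minimal sketch of this confidence-based fusion (my own illustration, not the repo's exact code): each branch outputs a depth map and a confidence map, and a per-pixel softmax over the two confidences produces fusion weights that sum to 1.

import torch

def fuse_depths(d_cd, d_dd, conf_cd, conf_dd):
    # per-pixel softmax over the two confidence maps -> weights summing to 1
    w = torch.softmax(torch.cat((conf_cd, conf_dd), dim=1), dim=1)
    # weighted sum of the CD- and DD-branch depth predictions
    return w[:, 0:1] * d_cd + w[:, 1:2] * d_dd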

  • The Geometric Convolutional Layer

    1. augments a conventional convolutional layer via concatenating a 3D position map to the layer’s input.

    2. we replace each convolutional layer within the ResBlocks by the proposed geometric convolutional layer.
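
A hedged sketch of the idea (the class name and interface are my own; in the repo this lives inside BasicBlockGeo, shown in the code-reading section below): back-project every pixel to a 3D position map using the depth and camera intrinsics, concatenate it to the layer input, and run an ordinary convolution.

import torch
import torch.nn as nn

class GeometricConv(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, stride=1):
        super().__init__()
        # an ordinary conv that expects 3 extra position channels
        self.conv = nn.Conv2d(in_ch + 3, out_ch, k, stride, k // 2)

    def forward(self, feat, depth, K):
        b, _, h, w = depth.shape
        v, u = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
        u = u.to(depth).expand(b, 1, h, w)
        v = v.to(depth).expand(b, 1, h, w)
        # pinhole back-projection: pixel (u, v) with depth Z -> (X, Y, Z)
        X = (u - K[0, 2]) * depth / K[0, 0]
        Y = (v - K[1, 2]) * depth / K[1, 1]
        pos = torch.cat((X, Y, depth), dim=1)  # the 3D position map
        return self.conv(torch.cat((feat, pos), dim=1))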

  • The Dilated and Accelerated CSPN++

    recover the depth values at valid pixels

    introduce a dilation strategy similar to the well-known dilated convolutions to enlarge the propagation neighborhoods.

    our implementation of the translations can be performed in parallel (more efficient); a toy single-step version is sketched below
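
An illustrative single propagation step built on F.unfold (my own sketch, not the repo's accelerated implementation; it assumes the 8 neighbor affinities are already normalized so the leftover weight goes to the center pixel):

import torch
import torch.nn.functional as F

def cspn_step(depth, affinity, sparse_d, mask, dilation=2):
    b, _, h, w = depth.shape
    # gather the 3x3 *dilated* neighborhood of every pixel: (B, 9, H, W)
    nb = F.unfold(depth, 3, dilation=dilation, padding=dilation).view(b, 9, h, w)
    nb = torch.cat((nb[:, :4], nb[:, 5:]), dim=1)       # drop the center entry
    center_w = 1.0 - affinity.sum(dim=1, keepdim=True)  # remaining weight for the center
    out = (affinity * nb).sum(dim=1, keepdim=True) + center_w * depth
    return mask * sparse_d + (1 - mask) * out           # re-impose valid measurements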

Experiments

(https://github.com/JUGGHM/PENet_ICRA2021)

  • two-branch backbone

  • the geometric convolutional layer

  • the DA-CSPN++ module.

  • we obtain four variants of the backbone, B1 to B4

  • Based on the backbone model B4, we further replace each convolutional layer in the ResBlocks by our proposed geometric convolutional layer and get the model B4+GCL.

  • Based on the backbone model B4, we integrate variants of CSPN++ to compare their performance. The total number of iterations for propagation is 12.

  • C1 stands for original CSPN++, with a dilation rate (dr) of 1 for all iterations.

  • C2 stands for the model that takes dr = 2 for the first six iterations and dr = 1 for the remaining iterations.

  • C4 is the model taking dr = {4, 2, 1}, each for four iterations

  • the model B4+C2 slightly outperforms the other two counterparts

  • ENet: We also test our geometry-encoded backbone without depth refinement (referred to as ENet)

  • PENet: we present the quantitative performance of our full method (referred to as PENet)

Code Reading

The channel order of PyTorch tensors is NCHW.

Why downsample this way? SparseDownSampleClose keeps the closest valid depth in each pooling window: valid pixels are encoded as -d, so MaxPool2d picks the minimum (i.e. nearest) depth, while invalid pixels are pushed down by large_number so they are never selected; windows containing no valid pixel are flagged again on the way out.

class SparseDownSampleClose(nn.Module):
    def __init__(self, stride):
        super(SparseDownSampleClose, self).__init__()
        self.pooling = nn.MaxPool2d(stride, stride)
        self.large_number = 600
    def forward(self, d, mask):
        encode_d = - (1-mask)*self.large_number - d

        d = - self.pooling(encode_d)
        mask_result = self.pooling(mask)
        d_result = d - (1-mask_result)*self.large_number

        return d_result, mask_result
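
A quick sanity check of the trick (toy tensors of my own):

import torch

d = torch.tensor([[[[5., 0.],
                    [2., 9.]]]])   # one 2x2 window, 0 = missing
mask = (d > 0).float()
down = SparseDownSampleClose(stride=2)
d_out, mask_out = down(d, mask)
print(d_out)     # tensor([[[[2.]]]]) -- the closest valid depth, not the max
print(mask_out)  # tensor([[[[1.]]]]) -- the window contained a valid pixel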

Running load_calib yields K:

[[721.5377   0.     596.5593]
 [  0.     721.5377 161.354 ]
 [  0.       0.       1.    ]]
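
With this K, a pixel (u, v) with depth d back-projects to camera coordinates in the usual pinhole way (a small illustration with values of my choosing):

# fx = fy = 721.5377, cx = 596.5593, cy = 161.354; u, v in pixels, d in meters
u, v, d = 700.0, 200.0, 10.0
X = (u - 596.5593) * d / 721.5377   # ~1.43 m right of the optical axis
Y = (v - 161.354) * d / 721.5377    # ~0.54 m below the optical axis
# (X, Y, d) is the 3D point used to build the position maps for the GCL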

BasicBlockGeo simply replaces the convolutions inside the residual block with the geometric convolutional layer (GCL):

class BasicBlock(nn.Module):
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
            #norm_layer = encoding.nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        if stride != 1 or inplanes != planes:
            downsample = nn.Sequential(
                conv1x1(inplanes, planes, stride),
                norm_layer(planes),
            )
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class BasicBlockGeo(nn.Module):
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, geoplanes=3):
        super(BasicBlockGeo, self).__init__()

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
            #norm_layer = encoding.nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes + geoplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes+geoplanes, planes)
        self.bn2 = norm_layer(planes)
        if stride != 1 or inplanes != planes:
            downsample = nn.Sequential(
                conv1x1(inplanes+geoplanes, planes, stride),
                norm_layer(planes),
            )
        self.downsample = downsample
        self.stride = stride

    def forward(self, x, g1=None, g2=None):
        identity = x
        if g1 is not None:
            x = torch.cat((x, g1), 1)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        if g2 is not None:
            out = torch.cat((g2,out), 1)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
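
Illustrative usage (the shapes are my own assumptions): g1 matches the block's input resolution and g2 matches the post-stride output resolution, so the two concatenations line up with the inplanes + geoplanes and planes + geoplanes channel counts above.

import torch

block = BasicBlockGeo(inplanes=32, planes=64, stride=2, geoplanes=3)
x = torch.randn(1, 32, 64, 64)
g1 = torch.randn(1, 3, 64, 64)   # geometry at the block's input resolution
g2 = torch.randn(1, 3, 32, 32)   # geometry at the post-stride resolution
out = block(x, g1, g2)
print(out.shape)                 # torch.Size([1, 64, 32, 32])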

main.py

parser

  • network-model
  • workers
  • epochs
  • start-epoch
  • start-epoch-bias
  • criterion
  • batch-size
  • lr
  • weight-decay
  • print-freq
  • resume
  • data-folder
  • data-folder-rgb
  • data-folder-save
  • input
  • val
  • jitter
  • rank-metric
  • evaluate
  • freeze-backbone
  • test
  • cpu
  • not-random-crop
  • random-crop-height
  • random-crop-width
  • convolutional-layer-encoding
  • dilation-rate
args = parser.parse_args()
args.use_rgb = ('rgb' in args.input)
args.use_d = 'd' in args.input
args.use_g = 'g' in args.input
args.result = os.path.join('..', 'results')
args.val_h = 352
args.val_w = 1216
print(args)

CUDA or CPU

cuda = torch.cuda.is_available() and not args.cpu
if cuda:
    import torch.backends.cudnn as cudnn
    cudnn.benchmark = True
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("=> using '{}' for computation.".format(device))

loss function

depth_criterion = criteria.MaskedMSELoss() if (
    args.criterion == 'l2') else criteria.MaskedL1Loss()
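
criteria.py itself is not reproduced here; a minimal sketch of what a masked L2 loss looks like (assuming, consistent with the sparse KITTI ground truth, that only pixels with target > 0 are supervised):

import torch
import torch.nn as nn

class MaskedMSELossSketch(nn.Module):
    def forward(self, pred, target):
        valid = (target > 0).detach()  # sparse GT: supervise valid pixels only
        diff = (pred - target)[valid]
        return (diff ** 2).mean()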

def iterate(mode, args, loader, model, optimizer, logger, epoch):

    assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
        "unsupported mode: {}".format(mode)

def main():

def main():
    global args
    checkpoint = None
    is_eval = False
    if args.evaluate:
        args_new = args
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}' ... ".format(args.evaluate),
                  end='')
            checkpoint = torch.load(args.evaluate, map_location=device)
            #args = checkpoint['args']
            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            is_eval = True

            print("Completed.")
        else:
            is_eval = True
            print("No model found at '{}'".format(args.evaluate))
            #return

    elif args.resume:  # optionally resume from a checkpoint
        args_new = args
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)

            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer ... ", end='')
    model = None
    penet_accelerated = False
    if (args.network_model == 'e'):
        model = ENet(args).to(device)
    elif (is_eval == False):
        if (args.dilation_rate == 1):
            model = PENet_C1_train(args).to(device)
        elif (args.dilation_rate == 2):
            model = PENet_C2_train(args).to(device)
        elif (args.dilation_rate == 4):
            model = PENet_C4(args).to(device)
            penet_accelerated = True
    else:
        if (args.dilation_rate == 1):
            model = PENet_C1(args).to(device)
            penet_accelerated = True
        elif (args.dilation_rate == 2):
            model = PENet_C2(args).to(device)
            penet_accelerated = True
        elif (args.dilation_rate == 4):
            model = PENet_C4(args).to(device)
            penet_accelerated = True

    if (penet_accelerated == True):
        model.encoder3.requires_grad = False
        model.encoder5.requires_grad = False
        model.encoder7.requires_grad = False

    model_named_params = None
    model_bone_params = None
    model_new_params = None
    optimizer = None

    if checkpoint is not None:
        #print(checkpoint.keys())
        if (args.freeze_backbone == True):
            model.backbone.load_state_dict(checkpoint['model'])
        else:
            model.load_state_dict(checkpoint['model'], strict=False)
        #optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    logger = helper.logger(args)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
        del checkpoint
    print("=> logger created.")

    test_dataset = None
    test_loader = None
    if (args.test):
        test_dataset = KittiDepth('test_completion', args)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=1,
            pin_memory=True)
        iterate("test_completion", args, test_loader, model, None, logger, 0)
        return

    val_dataset = KittiDepth('val', args)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True)  # set batch size to be 1 for validation
    print("\t==> val_loader size:{}".format(len(val_loader)))

    if is_eval == True:
        for p in model.parameters():
            p.requires_grad = False

        result, is_best = iterate("val", args, val_loader, model, None, logger,
                              args.start_epoch - 1)
        return

    if (args.freeze_backbone == True):
        for p in model.backbone.parameters():
            p.requires_grad = False
        model_named_params = [
            p for _, p in model.named_parameters() if p.requires_grad
        ]
        optimizer = torch.optim.Adam(model_named_params, lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.99))
    elif (args.network_model == 'pe'):
        model_bone_params = [
            p for _, p in model.backbone.named_parameters() if p.requires_grad
        ]
        model_new_params = [
            p for _, p in model.named_parameters() if p.requires_grad
        ]
        model_new_params = list(set(model_new_params) - set(model_bone_params))
        optimizer = torch.optim.Adam([{'params': model_bone_params, 'lr': args.lr / 10}, {'params': model_new_params}],
                                     lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.99))
    else:
        model_named_params = [
            p for _, p in model.named_parameters() if p.requires_grad
        ]
        optimizer = torch.optim.Adam(model_named_params, lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.99))
    print("completed.")

    model = torch.nn.DataParallel(model)

    # Data loading code
    print("=> creating data loaders ... ")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)
        print("\t==> train_loader size:{}".format(len(train_loader)))

    print("=> starting main loop ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("=> starting training epoch {} ..".format(epoch))
        iterate("train", args, train_loader, model, optimizer, logger, epoch)  # train for one epoch

        # validation memory reset
        for p in model.parameters():
            p.requires_grad = False
        result, is_best = iterate("val", args, val_loader, model, None, logger, epoch)  # evaluate on validation set

        for p in model.parameters():
            p.requires_grad = True
        if (args.freeze_backbone == True):
            for p in model.module.backbone.parameters():
                p.requires_grad = False
        if (penet_accelerated == True):
            model.module.encoder3.requires_grad = False
            model.module.encoder5.requires_grad = False
            model.module.encoder7.requires_grad = False

        helper.save_checkpoint({ # save checkpoint
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer' : optimizer.state_dict(),
            'args' : args,
        }, is_best, epoch, logger.output_directory)


kitti_loader.py

input_options = ['d', 'rgb', 'rgbd', 'g', 'gd']

def load_calib():

def load_calib():
   Temporarily hard-codes the calibration matrix using the 2011_09_26 calib file

def get_paths_and_transform

# fragment: the test_completion branch of get_paths_and_transform
transform = no_transform
glob_d = os.path.join(args.data_folder, "data_depth_selection/test_depth_completion_anonymous/velodyne_raw/*.png")
glob_gt = None  # "test_depth_completion_anonymous/"
glob_rgb = os.path.join(args.data_folder, "data_depth_selection/test_depth_completion_anonymous/image/*.png")
paths_rgb = sorted(glob.glob(glob_rgb))
paths_gt = [None] * len(paths_rgb)
paths_d = sorted(glob.glob(glob_d))
paths = {"rgb": paths_rgb, "d": paths_d, "gt": paths_gt}
return paths, transform

def rgb_read(filename):

def rgb_read(filename):
    assert os.path.exists(filename), "file not found: {}".format(filename)
    img_file = Image.open(filename)
    # rgb_png = np.array(img_file, dtype=float) / 255.0 # scale pixels to the range [0,1]
    rgb_png = np.array(img_file, dtype='uint8')  # in the range [0,255]
    img_file.close()
    return rgb_png

def depth_read(filename):

def depth_read(filename):
    # loads depth map D from png file
    # and returns it as a numpy array,
    # for details see readme.txt
    assert os.path.exists(filename), "file not found: {}".format(filename)
    img_file = Image.open(filename)
    depth_png = np.array(img_file, dtype=int)
    img_file.close()
    # make sure we have a proper 16bit depth map here.. not 8bit!
    assert np.max(depth_png) > 255, \
        "np.max(depth_png)={}, path={}".format(np.max(depth_png), filename)
    depth = depth_png.astype(float) / 256.  # note: np.float was removed in recent NumPy; plain float is safe
    # depth[depth_png == 0] = -1.
    depth = np.expand_dims(depth, -1)
    return depth

def drop_depth_measurements

def drop_depth_measurements(depth, prob_keep):
    mask = np.random.binomial(1, prob_keep, depth.shape)
    depth *= mask
    return depth

def train_transform(rgb, sparse, target, position, args): augments and crops the raw images for training

def train_transform(rgb, sparse, target, position, args):
    oheight = args.val_h
    owidth = args.val_w

    do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip

    transforms_list = [
        # transforms.Rotate(angle),
        # transforms.Resize(s),
        transforms.BottomCrop((oheight, owidth)),
        transforms.HorizontalFlip(do_flip)
    ]
    # Compose chains several image transforms into one callable
    transform_geometric = transforms.Compose(transforms_list)

    if sparse is not None:
        sparse = transform_geometric(sparse)
    target = transform_geometric(target)
    if rgb is not None:
        brightness = np.random.uniform(max(0, 1 - args.jitter),
                                       1 + args.jitter)
        contrast = np.random.uniform(max(0, 1 - args.jitter), 1 + args.jitter)
        saturation = np.random.uniform(max(0, 1 - args.jitter),
                                       1 + args.jitter)
        transform_rgb = transforms.Compose([
            transforms.ColorJitter(brightness, contrast, saturation, 0),
            transform_geometric
        ])
        rgb = transform_rgb(rgb)
    if position is not None:
        bottom_crop_only = transforms.Compose([transforms.BottomCrop((oheight, owidth))])
        position = bottom_crop_only(position)

    # random crop
    if args.not_random_crop == False:
        h = oheight
        w = owidth
        rheight = args.random_crop_height
        rwidth = args.random_crop_width
        # randomize the crop origin
        i = np.random.randint(0, h - rheight + 1)
        j = np.random.randint(0, w - rwidth + 1)

        if rgb is not None:
            if rgb.ndim == 3:
                rgb = rgb[i:i + rheight, j:j + rwidth, :]
            elif rgb.ndim == 2:
                rgb = rgb[i:i + rheight, j:j + rwidth]

        if sparse is not None:
            if sparse.ndim == 3:
                sparse = sparse[i:i + rheight, j:j + rwidth, :]
            elif sparse.ndim == 2:
                sparse = sparse[i:i + rheight, j:j + rwidth]

        if target is not None:
            if target.ndim == 3:
                target = target[i:i + rheight, j:j + rwidth, :]
            elif target.ndim == 2:
                target = target[i:i + rheight, j:j + rwidth]

        if position is not None:
            if position.ndim == 3:
                position = position[i:i + rheight, j:j + rwidth, :]
            elif position.ndim == 2:
                position = position[i:i + rheight, j:j + rwidth]

    return rgb, sparse, target, position

def val_transform(rgb, sparse, target, position, args):

def val_transform(rgb, sparse, target, position, args):
    oheight = args.val_h
    owidth = args.val_w

    transform = transforms.Compose([
        transforms.BottomCrop((oheight, owidth)),
    ])
    if rgb is not None:
        rgb = transform(rgb)
    if sparse is not None:
        sparse = transform(sparse)
    if target is not None:
        target = transform(target)
    if position is not None:
        position = transform(position)

    return rgb, sparse, target, position

def no_transform(rgb, sparse, target, position, args):

def no_transform(rgb, sparse, target, position, args):
    return rgb, sparse, target, position

to_float_tensor

to_tensor = transforms.ToTensor()
to_float_tensor = lambda x: to_tensor(x).float()

def handle_gray(rgb, args): converts the image to grayscale

def handle_gray(rgb, args):
    if rgb is None:
        return None, None
    if not args.use_g:
        return rgb, None
    else:
        img = np.array(Image.fromarray(rgb).convert('L'))
        img = np.expand_dims(img, -1)
        if not args.use_rgb:
            rgb_ret = None
        else:
            rgb_ret = rgb
        return rgb_ret, img
img.convert('L')
   Converts to a grayscale image: each pixel is 8 bits, 0 is black, 255 is white, and the values in between are shades of gray.
   Conversion formula: L = R * 299/1000 + G * 587/1000 + B * 114/1000
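
A quick check of the formula against PIL (rounding may differ by 1 for some values):

import numpy as np
from PIL import Image

rgb = np.array([[[200, 100, 50]]], dtype=np.uint8)
gray_pil = int(np.array(Image.fromarray(rgb).convert('L'))[0, 0])
gray_manual = (200 * 299 + 100 * 587 + 50 * 114) // 1000
print(gray_pil, gray_manual)  # both 124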

def get_rgb_near(path, args):

def get_rgb_near(path, args):
    assert path is not None, "path is None"

    def extract_frame_id(filename):
        head, tail = os.path.split(filename)
        number_string = tail[0:tail.find('.')]
        number = int(number_string)
        return head, number

    def get_nearby_filename(filename, new_id):
        head, _ = os.path.split(filename)
        new_filename = os.path.join(head, '%010d.png' % new_id)
        return new_filename

    head, number = extract_frame_id(path)
    count = 0
    max_frame_diff = 3
    candidates = [
        i - max_frame_diff for i in range(max_frame_diff * 2 + 1)
        if i - max_frame_diff != 0
    ]
    while True:
        random_offset = choice(candidates)
        path_near = get_nearby_filename(path, number + random_offset)
        if os.path.exists(path_near):
            break
        count += 1  # bug fix: without this increment the assert below can never fire
        assert count < 20, "cannot find a nearby frame in 20 trials for {}".format(path_near)

    return rgb_read(path_near)

class KittiDepth(data.Dataset):

candidates = {"rgb": rgb, "d": sparse, "gt": target, \
              "g": gray, 'position': position, 'K': self.K}
        items = {
        key: to_float_tensor(val)
        for key, val in candidates.items() if val is not None
    }
    return items

都转成tensor

class KittiDepth(data.Dataset):
    """A data loader for the Kitti dataset
    """

    def __init__(self, split, args):
        self.args = args
        self.split = split
        paths, transform = get_paths_and_transform(split, args)
        self.paths = paths
        self.transform = transform
        self.K = load_calib()
        self.threshold_translation = 0.1

    def __getraw__(self, index):
        rgb = rgb_read(self.paths['rgb'][index]) if \
            (self.paths['rgb'][index] is not None and (self.args.use_rgb or self.args.use_g)) else None
        sparse = depth_read(self.paths['d'][index]) if \
            (self.paths['d'][index] is not None and self.args.use_d) else None
        target = depth_read(self.paths['gt'][index]) if \
            self.paths['gt'][index] is not None else None
        return rgb, sparse, target

    def __getitem__(self, index):
        rgb, sparse, target = self.__getraw__(index)
        position = CoordConv.AddCoordsNp(self.args.val_h, self.args.val_w)
        position = position.call()
        rgb, sparse, target, position = self.transform(rgb, sparse, target, position, self.args)

        rgb, gray = handle_gray(rgb, self.args)
        # candidates = {"rgb": rgb, "d": sparse, "gt": target, \
        #              "g": gray, "r_mat": r_mat, "t_vec": t_vec, "rgb_near": rgb_near}
        candidates = {"rgb": rgb, "d": sparse, "gt": target, \
                      "g": gray, 'position': position, 'K': self.K}

        items = {
            key: to_float_tensor(val)
            for key, val in candidates.items() if val is not None
        }

        return items

    def __len__(self):
        return len(self.paths['gt'])
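
Illustrative usage (assuming args has been parsed and the KITTI folder layout exists; which keys appear depends on args.input):

dataset = KittiDepth('val', args)
sample = dataset[0]
print(sample.keys())      # e.g. dict_keys(['rgb', 'd', 'gt', 'position', 'K'])
print(sample['d'].shape)  # torch.Size([1, 352, 1216]) after BottomCrop to val_h x val_w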

CoordConv.py

  • adds x, y coordinate layers

criteria.py

  • selects the loss function: L1 or L2

helper.py

fieldnames = [
    'epoch', 'rmse', 'photo', 'mae', 'irmse', 'imae', 'mse', 'absrel', 'lg10',
    'silog', 'squared_rel', 'delta1', 'delta2', 'delta3', 'data_time',
    'gpu_time'
]

class logger:
  • output_directory
  • best_result
  • self.train_csv = os.path.join(output_directory, 'train.csv')
    self.val_csv = os.path.join(output_directory, 'val.csv')
    self.best_txt = os.path.join(output_directory, 'best.txt')
#backup the source code
print("=> creating source code backup ...")
backup_directory = os.path.join(output_directory, "code_backup")
self.backup_directory = backup_directory
backup_source_code(backup_directory)
with open(self.train_csv, 'w') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()
with open(self.val_csv, 'w') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()
print("=> finished creating source code backup.")



def conditional_print(self, split, i, epoch, lr, n_set, blk_avg_meter, avg_meter):
        if (i + 1) % self.args.print_freq == 0:
            avg = avg_meter.average()
            blk_avg = blk_avg_meter.average()
            print('=> output: {}'.format(self.output_directory))
            print(
                '{split} Epoch: {0} [{1}/{2}]\tlr={lr} '
                't_Data={blk_avg.data_time:.3f}({average.data_time:.3f}) '
                't_GPU={blk_avg.gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                'RMSE={blk_avg.rmse:.2f}({average.rmse:.2f}) '
                'MAE={blk_avg.mae:.2f}({average.mae:.2f}) '
                'iRMSE={blk_avg.irmse:.2f}({average.irmse:.2f}) '
                'iMAE={blk_avg.imae:.2f}({average.imae:.2f})\n\t'
                'silog={blk_avg.silog:.2f}({average.silog:.2f}) '
                'squared_rel={blk_avg.squared_rel:.2f}({average.squared_rel:.2f}) '
                'Delta1={blk_avg.delta1:.3f}({average.delta1:.3f}) '
                'REL={blk_avg.absrel:.3f}({average.absrel:.3f})\n\t'
                'Lg10={blk_avg.lg10:.3f}({average.lg10:.3f}) '
                'Photometric={blk_avg.photometric:.3f}({average.photometric:.3f}) '
                .format(epoch,
                        i + 1,
                        n_set,
                        lr=lr,
                        blk_avg=blk_avg,
                        average=avg,
                        split=split.capitalize()))
            blk_avg_meter.reset(False)
### write the CSV file
def conditional_save_info(self, split, average_meter, epoch):
        avg = average_meter.average()
        if split == "train":
            csvfile_name = self.train_csv
        elif split == "val":
            csvfile_name = self.val_csv
        elif split == "eval":
            eval_filename = os.path.join(self.output_directory, 'eval.txt')
            self.save_single_txt(eval_filename, avg, epoch)
            return avg
        elif "test" in split:
            return avg
        else:
            raise ValueError("wrong split provided to logger")
        with open(csvfile_name, 'a') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow({
                'epoch': epoch,
                'rmse': avg.rmse,
                'photo': avg.photometric,
                'mae': avg.mae,
                'irmse': avg.irmse,
                'imae': avg.imae,
                'mse': avg.mse,
                'silog': avg.silog,
                'squared_rel': avg.squared_rel,
                'absrel': avg.absrel,
                'lg10': avg.lg10,
                'delta1': avg.delta1,
                'delta2': avg.delta2,
                'delta3': avg.delta3,
                'gpu_time': avg.gpu_time,
                'data_time': avg.data_time
            })
        return avg

### write the txt file
    def save_single_txt(self, filename, result, epoch):
        with open(filename, 'w') as txtfile:
            txtfile.write(
                ("rank_metric={}\n" + "epoch={}\n" + "rmse={:.3f}\n" +
                 "mae={:.3f}\n" + "silog={:.3f}\n" + "squared_rel={:.3f}\n" +
                 "irmse={:.3f}\n" + "imae={:.3f}\n" + "mse={:.3f}\n" +
                 "absrel={:.3f}\n" + "lg10={:.3f}\n" + "delta1={:.3f}\n" +
                 "t_gpu={:.4f}").format(self.args.rank_metric, epoch,
                                        result.rmse, result.mae, result.silog,
                                        result.squared_rel, result.irmse,
                                        result.imae, result.mse, result.absrel,
                                        result.lg10, result.delta1,
                                        result.gpu_time))
    def save_best_txt(self, result, epoch):
        self.save_single_txt(self.best_txt, result, epoch)
        


def adjust_learning_rate(lr_init, optimizer, epoch, args):

def adjust_learning_rate(lr_init, optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 5 epochs"""
    #lr = lr_init * (0.5**(epoch // 5))
    #'''
    lr = lr_init
    if (args.network_model == 'pe' and args.freeze_backbone == False):
        if (epoch >= 10):
            lr = lr_init * 0.5
        if (epoch >= 20):
            lr = lr_init * 0.1
        if (epoch >= 30):
            lr = lr_init * 0.01
        if (epoch >= 40):
            lr = lr_init * 0.0005
        if (epoch >= 50):
            lr = lr_init * 0.00001
    else:
        if (epoch >= 10):
            lr = lr_init * 0.5
        if (epoch >= 15):
            lr = lr_init * 0.1
        if (epoch >= 25):
            lr = lr_init * 0.01
    #'''

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def save_checkpoint(state, is_best, epoch, output_directory):

def save_checkpoint(state, is_best, epoch, output_directory):
    checkpoint_filename = os.path.join(output_directory,
                                       'checkpoint-' + str(epoch) + '.pth.tar')
    torch.save(state, checkpoint_filename)
    if is_best:
        best_filename = os.path.join(output_directory, 'model_best.pth.tar')
        shutil.copyfile(checkpoint_filename, best_filename)
    if epoch > 0:
        prev_checkpoint_filename = os.path.join(
            output_directory, 'checkpoint-' + str(epoch - 1) + '.pth.tar')
        if os.path.exists(prev_checkpoint_filename):
            os.remove(prev_checkpoint_filename)

def backup_source_code(backup_directory):

  • definition of the source-code backup function
ignore_hidden = shutil.ignore_patterns(".", "..", ".git*", "*pycache*",
                                       "*build", "*.fuse*", "*_drive_*")
def backup_source_code(backup_directory):
    if os.path.exists(backup_directory):
        shutil.rmtree(backup_directory)
    shutil.copytree('.', backup_directory, ignore=ignore_hidden)

def get_folder_name(args):

def get_folder_name(args):
    current_time = time.strftime('%Y-%m-%d@%H-%M')
    return os.path.join(args.result,
        'input={}.criterion={}.lr={}.bs={}.wd={}.jitter={}.time={}'.
        format(args.input, args.criterion, \
            args.lr, args.batch_size, args.weight_decay, \
            args.jitter, current_time
            ))

metrics.py

class Result

class Result(object):
    def __init__(self):
        self.irmse = 0
        self.imae = 0
        self.mse = 0
        self.rmse = 0
        self.mae = 0
        self.absrel = 0
        self.squared_rel = 0
        self.lg10 = 0
        self.delta1 = 0
        self.delta2 = 0
        self.delta3 = 0
        self.data_time = 0
        self.gpu_time = 0
        self.silog = 0  # Scale invariant logarithmic error [log(m)*100]
        self.photometric = 0

    def set_to_worst(self):
        self.irmse = np.inf
        self.imae = np.inf
        self.mse = np.inf
        self.rmse = np.inf
        self.mae = np.inf
        self.absrel = np.inf
        self.squared_rel = np.inf
        self.lg10 = np.inf
        self.silog = np.inf
        self.delta1 = 0
        self.delta2 = 0
        self.delta3 = 0
        self.data_time = 0
        self.gpu_time = 0

    def update(self, irmse, imae, mse, rmse, mae, absrel, squared_rel, lg10, \
            delta1, delta2, delta3, gpu_time, data_time, silog, photometric=0):
        self.irmse = irmse
        self.imae = imae
        self.mse = mse
        self.rmse = rmse
        self.mae = mae
        self.absrel = absrel
        self.squared_rel = squared_rel
        self.lg10 = lg10
        self.delta1 = delta1
        self.delta2 = delta2
        self.delta3 = delta3
        self.data_time = data_time
        self.gpu_time = gpu_time
        self.silog = silog
        self.photometric = photometric

    def evaluate(self, output, target, photometric=0):
        valid_mask = target > 0.1

        # convert from meters to mm
        output_mm = 1e3 * output[valid_mask]
        target_mm = 1e3 * target[valid_mask]

        abs_diff = (output_mm - target_mm).abs()

        self.mse = float((torch.pow(abs_diff, 2)).mean())
        self.rmse = math.sqrt(self.mse)
        self.mae = float(abs_diff.mean())
        self.lg10 = float((log10(output_mm) - log10(target_mm)).abs().mean())
        self.absrel = float((abs_diff / target_mm).mean())
        self.squared_rel = float(((abs_diff / target_mm)**2).mean())

        maxRatio = torch.max(output_mm / target_mm, target_mm / output_mm)
        self.delta1 = float((maxRatio < 1.25).float().mean())
        self.delta2 = float((maxRatio < 1.25**2).float().mean())
        self.delta3 = float((maxRatio < 1.25**3).float().mean())
        self.data_time = 0
        self.gpu_time = 0

        # silog uses meters
        err_log = torch.log(target[valid_mask]) - torch.log(output[valid_mask])
        normalized_squared_log = (err_log**2).mean()
        log_mean = err_log.mean()
        self.silog = math.sqrt(normalized_squared_log -
                               log_mean * log_mean) * 100

        # convert from meters to km
        inv_output_km = (1e-3 * output[valid_mask])**(-1)
        inv_target_km = (1e-3 * target[valid_mask])**(-1)
        abs_inv_diff = (inv_output_km - inv_target_km).abs()
        self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean())
        self.imae = float(abs_inv_diff.mean())

        self.photometric = float(photometric)
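
A tiny worked example (assuming metrics.py's log10 helper is in scope):

import torch

r = Result()
output = torch.tensor([1.0, 2.0])   # meters
target = torch.tensor([1.2, 2.0])
r.evaluate(output, target)
print(r.mae)     # 100.0 -- mean of |1000-1200| and |2000-2000| in mm
print(r.delta1)  # 1.0   -- max(1.0/1.2, 1.2/1.0) = 1.2 < 1.25 at both pixels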

class AverageMeter(object):

class AverageMeter(object):
    def __init__(self):
        self.reset(time_stable=True)

    def reset(self, time_stable):
        self.count = 0.0
        self.sum_irmse = 0
        self.sum_imae = 0
        self.sum_mse = 0
        self.sum_rmse = 0
        self.sum_mae = 0
        self.sum_absrel = 0
        self.sum_squared_rel = 0
        self.sum_lg10 = 0
        self.sum_delta1 = 0
        self.sum_delta2 = 0
        self.sum_delta3 = 0
        self.sum_data_time = 0
        self.sum_gpu_time = 0
        self.sum_photometric = 0
        self.sum_silog = 0
        self.time_stable = time_stable
        self.time_stable_counter_init = 10
        self.time_stable_counter = self.time_stable_counter_init

    def update(self, result, gpu_time, data_time, n=1):
        self.count += n
        self.sum_irmse += n * result.irmse
        self.sum_imae += n * result.imae
        self.sum_mse += n * result.mse
        self.sum_rmse += n * result.rmse
        self.sum_mae += n * result.mae
        self.sum_absrel += n * result.absrel
        self.sum_squared_rel += n * result.squared_rel
        self.sum_lg10 += n * result.lg10
        self.sum_delta1 += n * result.delta1
        self.sum_delta2 += n * result.delta2
        self.sum_delta3 += n * result.delta3
        self.sum_data_time += n * data_time
        if self.time_stable == True and self.time_stable_counter > 0:
            self.time_stable_counter = self.time_stable_counter - 1
        else:
            self.sum_gpu_time += n * gpu_time
        self.sum_silog += n * result.silog
        self.sum_photometric += n * result.photometric

    def average(self):
        avg = Result()
        if self.time_stable == True:
            if self.count > 0 and self.count - self.time_stable_counter_init > 0:
                avg.update(
                    self.sum_irmse / self.count, self.sum_imae / self.count,
                    self.sum_mse / self.count, self.sum_rmse / self.count,
                    self.sum_mae / self.count, self.sum_absrel / self.count,
                    self.sum_squared_rel / self.count, self.sum_lg10 / self.count,
                    self.sum_delta1 / self.count, self.sum_delta2 / self.count,
                    self.sum_delta3 / self.count, self.sum_gpu_time / (self.count - self.time_stable_counter_init),
                    self.sum_data_time / self.count, self.sum_silog / self.count,
                    self.sum_photometric / self.count)
            elif self.count > 0:
                avg.update(
                    self.sum_irmse / self.count, self.sum_imae / self.count,
                    self.sum_mse / self.count, self.sum_rmse / self.count,
                    self.sum_mae / self.count, self.sum_absrel / self.count,
                    self.sum_squared_rel / self.count, self.sum_lg10 / self.count,
                    self.sum_delta1 / self.count, self.sum_delta2 / self.count,
                    self.sum_delta3 / self.count, 0,
                    self.sum_data_time / self.count, self.sum_silog / self.count,
                    self.sum_photometric / self.count)
        elif self.count > 0:
            avg.update(
                self.sum_irmse / self.count, self.sum_imae / self.count,
                self.sum_mse / self.count, self.sum_rmse / self.count,
                self.sum_mae / self.count, self.sum_absrel / self.count,
                self.sum_squared_rel / self.count, self.sum_lg10 / self.count,
                self.sum_delta1 / self.count, self.sum_delta2 / self.count,
                self.sum_delta3 / self.count, self.sum_gpu_time / self.count,
                self.sum_data_time / self.count, self.sum_silog / self.count,
                self.sum_photometric / self.count)
        return avg

vis_utils.py

def validcrop(img):

def depth_colorize(depth):

def feature_colorize(feature):

def mask_vis(mask):

def merge_into_row(ele, pred, predrgb=None, predg=None, extra=None, extra2=None, extrargb=None):

def add_row(img_merge, row):

def save_image(img_merge, filename):

def save_image_torch(rgb, filename):

def save_depth_as_uint16png(img, filename):

def save_depth_as_uint16png_upload(img, filename):
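
For reference, depth_colorize typically normalizes the depth map and applies a matplotlib colormap; a sketch under that assumption (not the repo's verbatim code):

import numpy as np
import matplotlib.pyplot as plt

def depth_colorize_sketch(depth):
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)
    return (255 * plt.cm.jet(depth)[:, :, :3]).astype('uint8')  # HxWx3 uint8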

basic.py
