Sparse and noisy LiDAR completion with RGB guidance and uncertainty: code walkthrough

The global branch uses ERFNet, which replaces the standard convolution blocks with non-bottleneck-1d modules (factorized 3x1 and 1x3 convolutions).

ERFNet code:

# ERFNet full model definition for Pytorch
# Sept 2017
# Eduardo Romera
#######################

import torch
import torch.nn as nn
import torch.nn.functional as F


class DownsamplerBlock (nn.Module):
    def __init__(self, ninput, noutput):
        super().__init__()

        self.conv = nn.Conv2d(ninput, noutput-ninput, (3, 3), stride=2, padding=1, bias=True)
        self.pool = nn.MaxPool2d(2, stride=2)
        self.bn = nn.BatchNorm2d(noutput, eps=1e-3)

    def forward(self, input):
        output = torch.cat([self.conv(input), self.pool(input)], 1)
        output = self.bn(output)
        return F.relu(output)


class non_bottleneck_1d (nn.Module):
    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)

        self.conv1x3_1 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)

        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1*dilated,0), bias=True, dilation=(dilated, 1))

        self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1*dilated), bias=True, dilation=(1, dilated))

        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):

        output = self.conv3x1_1(input)
        output = F.relu(output)
        output = self.conv1x3_1(output)
        output = self.bn1(output)
        output = F.relu(output)

        output = self.conv3x1_2(output)
        output = F.relu(output)
        output = self.conv1x3_2(output)
        output = self.bn2(output)

        if (self.dropout.p != 0):
            output = self.dropout(output)

        return F.relu(output+input)


class Encoder(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        chans = 32 if in_channels > 16 else 16
        self.initial_block = DownsamplerBlock(in_channels, chans)

        self.layers = nn.ModuleList()

        self.layers.append(DownsamplerBlock(chans, 64))

        for x in range(0, 5):
            self.layers.append(non_bottleneck_1d(64, 0.03, 1)) 

        self.layers.append(DownsamplerBlock(64, 128))

        for x in range(0, 2):
            self.layers.append(non_bottleneck_1d(128, 0.3, 2))
            self.layers.append(non_bottleneck_1d(128, 0.3, 4))
            self.layers.append(non_bottleneck_1d(128, 0.3, 8))
            self.layers.append(non_bottleneck_1d(128, 0.3, 16))

        #Only in encoder mode:
        self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True)

    def forward(self, input, predict=False):
        output = self.initial_block(input)

        for layer in self.layers:
            output = layer(output)

        if predict:
            output = self.output_conv(output)

        return output

class UpsamplerBlock (nn.Module):
    def __init__(self, ninput, noutput):
        super().__init__()
        self.conv = nn.ConvTranspose2d(ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True)
        self.bn = nn.BatchNorm2d(noutput, eps=1e-3)

    def forward(self, input):
        output = self.conv(input)
        output = self.bn(output)
        return F.relu(output)


class Decoder (nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.layer1 = UpsamplerBlock(128, 64)
        self.layer2 = non_bottleneck_1d(64, 0, 1)
        self.layer3 = non_bottleneck_1d(64, 0, 1) # 64x64x304

        self.layer4 = UpsamplerBlock(64, 32)
        self.layer5 = non_bottleneck_1d(32, 0, 1)
        self.layer6 = non_bottleneck_1d(32, 0, 1) # 32x128x608

        self.output_conv = nn.ConvTranspose2d(32, num_classes, 2, stride=2, padding=0, output_padding=0, bias=True)

    def forward(self, input):
        output = input
        output = self.layer1(output)
        output = self.layer2(output)
        output = self.layer3(output)
        em2 = output
        output = self.layer4(output)
        output = self.layer5(output)
        output = self.layer6(output)
        em1 = output

        output = self.output_conv(output)

        return output, em1, em2


class Net(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):  #use encoder to pass pretrained encoder
        super().__init__()
        self.encoder = Encoder(in_channels, out_channels)
        self.decoder = Decoder(out_channels)

    def forward(self, input, only_encode=False):
        if only_encode:
            return self.encoder.forward(input, predict=True)
        else:
            output = self.encoder(input)
            return self.decoder.forward(output)

Next comes the overall model, which contains the hourglass networks.

import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
import numpy as np
from .ERFNet import Net

class uncertainty_net(nn.Module):
    def __init__(self, in_channels, out_channels=1, thres=15):
        super(uncertainty_net, self).__init__()
        out_chan = 2

        combine = 'concat'
        self.combine = combine
        self.in_channels = in_channels

        out_channels = 3
        self.depthnet = Net(in_channels=in_channels, out_channels=out_channels)

        local_channels_in = 2 if self.combine == 'concat' else 1
        self.convbnrelu = nn.Sequential(convbn(local_channels_in, 32, 3, 1, 1, 1),
                                        nn.ReLU(inplace=True))
        self.hourglass1 = hourglass_1(32)
        self.hourglass2 = hourglass_2(32)
        self.fuse = nn.Sequential(convbn(32, 32, 3, 1, 1, 1),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(32, out_chan, kernel_size=3, padding=1, stride=1, bias=True))
        self.activation = nn.ReLU(inplace=True)
        self.thres = thres
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, input, epoch=50):
        if self.in_channels > 1:
            rgb_in = input[:, 1:, :, :]
            lidar_in = input[:, 0:1, :, :]
        else:
            lidar_in = input

        # 1. GLOBAL NET
        embedding0, embedding1, embedding2 = self.depthnet(input)

        global_features = embedding0[:, 0:1, :, :]
        precise_depth = embedding0[:, 1:2, :, :]
        conf = embedding0[:, 2:, :, :]

        # 2. Fuse 
        if self.combine == 'concat':
            input = torch.cat((lidar_in, global_features), 1)
        elif self.combine == 'add':
            input = lidar_in + global_features
        elif self.combine == 'mul':
            input = lidar_in * global_features
        elif self.combine == 'sigmoid':
            input = lidar_in * nn.Sigmoid()(global_features)
        else:
            input = lidar_in

        # 3. LOCAL NET
        out = self.convbnrelu(input)
        out1, embedding3, embedding4 = self.hourglass1(out, embedding1, embedding2)
        out1 = out1 + out
        out2 = self.hourglass2(out1, embedding3, embedding4)
        out2 = out2 + out
        out = self.fuse(out2)
        lidar_out = out

        # 4. Late Fusion
        lidar_to_depth, lidar_to_conf = torch.chunk(out, 2, dim=1)
        lidar_to_conf, conf = torch.chunk(self.softmax(torch.cat((lidar_to_conf, conf), 1)), 2, dim=1)
        out = conf * precise_depth + lidar_to_conf * lidar_to_depth

        return out, lidar_out, precise_depth, global_features


def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation):

    return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, dilation=dilation, bias=False))
                         # nn.BatchNorm2d(out_planes))
class hourglass_1(nn.Module):
    def __init__(self, channels_in):
        super(hourglass_1, self).__init__()

        self.conv1 = nn.Sequential(convbn(channels_in, channels_in, kernel_size=3, stride=2, pad=1, dilation=1),
                                   nn.ReLU(inplace=True))

        self.conv2 = convbn(channels_in, channels_in, kernel_size=3, stride=1, pad=1, dilation=1)

        self.conv3 = nn.Sequential(convbn(channels_in*2, channels_in*2, kernel_size=3, stride=2, pad=1, dilation=1),
                                   nn.ReLU(inplace=True))

        self.conv4 = nn.Sequential(convbn(channels_in*2, channels_in*2, kernel_size=3, stride=1, pad=1, dilation=1))

        self.conv5 = nn.Sequential(nn.ConvTranspose2d(channels_in*4, channels_in*2, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False),
                                   nn.BatchNorm2d(channels_in*2),
                                   nn.ReLU(inplace=True))

        self.conv6 = nn.Sequential(nn.ConvTranspose2d(channels_in*2, channels_in, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False),
                                   nn.BatchNorm2d(channels_in))

    def forward(self, x, em1, em2):
        x = self.conv1(x)
        x = self.conv2(x)
        x = F.relu(x, inplace=True)
        x = torch.cat((x, em1), 1)

        x_prime = self.conv3(x)
        x_prime = self.conv4(x_prime)
        x_prime = F.relu(x_prime, inplace=True)
        x_prime = torch.cat((x_prime, em2), 1)

        out = self.conv5(x_prime)
        out = self.conv6(out)
        return out, x, x_prime
class hourglass_2(nn.Module):
    def __init__(self, channels_in):
        super(hourglass_2, self).__init__()

        self.conv1 = nn.Sequential(convbn(channels_in, channels_in*2, kernel_size=3, stride=2, pad=1, dilation=1),
                                   nn.BatchNorm2d(channels_in*2),
                                   nn.ReLU(inplace=True))

        self.conv2 = convbn(channels_in*2, channels_in*2, kernel_size=3, stride=1, pad=1, dilation=1)

        self.conv3 = nn.Sequential(convbn(channels_in*2, channels_in*2, kernel_size=3, stride=2, pad=1, dilation=1),
                                   nn.BatchNorm2d(channels_in*2),
                                   nn.ReLU(inplace=True))

        self.conv4 = nn.Sequential(convbn(channels_in*2, channels_in*4, kernel_size=3, stride=1, pad=1, dilation=1))

        self.conv5 = nn.Sequential(nn.ConvTranspose2d(channels_in*4, channels_in*2, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False),
                                   nn.BatchNorm2d(channels_in*2),
                                   nn.ReLU(inplace=True))

        self.conv6 = nn.Sequential(nn.ConvTranspose2d(channels_in*2, channels_in, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False),
                                   nn.BatchNorm2d(channels_in))

    def forward(self, x, em1, em2):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x + em1
        x = F.relu(x, inplace=True)

        x_prime = self.conv3(x)
        x_prime = self.conv4(x_prime)
        x_prime = x_prime + em2
        x_prime = F.relu(x_prime, inplace=True)

        out = self.conv5(x_prime)
        out = self.conv6(out)

        return out 
if __name__ == '__main__':
    batch_size = 4
    in_channels = 4
    H, W = 256, 1216
    model = uncertainty_net(in_channels).cuda()
    print(model)
    print("Number of parameters in model is {:.3f}M".format(sum(tensor.numel() for tensor in model.parameters())/1e6))
    input = torch.rand((batch_size, in_channels, H, W)).cuda().float()
    out = model(input)
    print(out[0].shape)

First, look at the forward function of the main network:

in_channels = 4
H, W = 256, 1216
model = uncertainty_net(in_channels).cuda()

Since in_channels is 4, two channel slices split the input into the RGB image [4, 3, 256, 1216] and the LiDAR map [4, 1, 256, 1216].

Then the global net:
The input goes through self.depthnet = Net(in_channels=in_channels, out_channels=out_channels) and produces three outputs; Net is exactly the ERFNet defined above, and the three outputs are the decoder's outputs.

class Decoder (nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.layer1 = UpsamplerBlock(128, 64)
        self.layer2 = non_bottleneck_1d(64, 0, 1)
        self.layer3 = non_bottleneck_1d(64, 0, 1) # 64x64x304

        self.layer4 = UpsamplerBlock(64, 32)
        self.layer5 = non_bottleneck_1d(32, 0, 1)
        self.layer6 = non_bottleneck_1d(32, 0, 1) # 32x128x608

        self.output_conv = nn.ConvTranspose2d(32, num_classes, 2, stride=2, padding=0, output_padding=0, bias=True)

    def forward(self, input):
        output = input
        output = self.layer1(output)
        output = self.layer2(output)
        output = self.layer3(output)
        em2 = output
        output = self.layer4(output)
        output = self.layer5(output)
        output = self.layer6(output)
        em1 = output
        output = self.output_conv(output)
        return output, em1, em2

The input (4, 4, 256, 1216) first enters the ERFNet encoder; the initial block's convolution and pooling outputs are concatenated, giving (4, 16, 128, 608). The features then pass through self.layers: a downsampler to 64 channels, five non_bottleneck_1d blocks, a downsampler to 128 channels, and two groups of four dilated non_bottleneck_1d blocks (15 modules in total).
The encoder output, (4, 128, 32, 152), is fed into the decoder. Transposed convolutions upsample it, interleaved with non_bottleneck_1d blocks, producing em2 (4, 64, 64, 304), em1 (4, 32, 128, 608) and output (4, 3, 256, 1216).
output is then sliced: global_features, precise_depth and conf each have shape torch.Size([4, 1, 256, 1216]).
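To double-check these shapes, here is a minimal sketch (assuming the Net class from the ERFNet code above is in scope) that pushes a dummy tensor through the encoder and decoder and prints the intermediate shapes:

import torch

# Dummy forward pass through the global ERFNet branch (Net, Encoder, Decoder as defined above).
net = Net(in_channels=4, out_channels=3)
x = torch.rand(4, 4, 256, 1216)              # (batch, RGB + LiDAR channels, H, W)

with torch.no_grad():
    enc = net.encoder(x)                     # encoder features
    output, em1, em2 = net.decoder(enc)

print(enc.shape)      # torch.Size([4, 128, 32, 152])
print(em2.shape)      # torch.Size([4, 64, 64, 304])
print(em1.shape)      # torch.Size([4, 32, 128, 608])
print(output.shape)   # torch.Size([4, 3, 256, 1216])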
Step 2: fusion.

        # 2. Fuse 
        if self.combine == 'concat':
            input = torch.cat((lidar_in, global_features), 1)
        elif self.combine == 'add':
            input = lidar_in + global_features
        elif self.combine == 'mul':
            input = lidar_in * global_features
        elif self.combine == 'sigmoid':
            input = lidar_in * nn.Sigmoid()(global_features)
        else:
            input = lidar_in

combine is 'concat', so the LiDAR data and the global features are concatenated along the channel dimension to form the new input.
Step 3: the local branch.

        # 3. LOCAL NET
        out = self.convbnrelu(input)
        out1, embedding3, embedding4 = self.hourglass1(out, embedding1, embedding2)
        out1 = out1 + out
        out2 = self.hourglass2(out1, embedding3, embedding4)
        out2 = out2 + out
        out = self.fuse(out2)
        lidar_out = out

The new input first goes through a convolution that brings it to 32 channels; then out, embedding1 and embedding2 are fed together into hourglass1. Here is the hourglass structure.

class hourglass_1(nn.Module):
    def __init__(self, channels_in):
        super(hourglass_1, self).__init__()

        self.conv1 = nn.Sequential(convbn(channels_in, channels_in, kernel_size=3, stride=2, pad=1, dilation=1),
                                   nn.ReLU(inplace=True))

        self.conv2 = convbn(channels_in, channels_in, kernel_size=3, stride=1, pad=1, dilation=1)

        self.conv3 = nn.Sequential(convbn(channels_in*2, channels_in*2, kernel_size=3, stride=2, pad=1, dilation=1),
                                   nn.ReLU(inplace=True))

        self.conv4 = nn.Sequential(convbn(channels_in*2, channels_in*2, kernel_size=3, stride=1, pad=1, dilation=1))

        self.conv5 = nn.Sequential(nn.ConvTranspose2d(channels_in*4, channels_in*2, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False),
                                   nn.BatchNorm2d(channels_in*2),
                                   nn.ReLU(inplace=True))

        self.conv6 = nn.Sequential(nn.ConvTranspose2d(channels_in*2, channels_in, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False),
                                   nn.BatchNorm2d(channels_in))

    def forward(self, x, em1, em2):
        x = self.conv1(x)
        x = self.conv2(x)
        x = F.relu(x, inplace=True)
        x = torch.cat((x, em1), 1)

        x_prime = self.conv3(x)
        x_prime = self.conv4(x_prime)
        x_prime = F.relu(x_prime, inplace=True)
        x_prime = torch.cat((x_prime, em2), 1)
        out = self.conv5(x_prime)
        out = self.conv6(out)
        return out, x, x_prime

The outputs out, x, x_prime have shapes (4, 32, 256, 1216), (4, 64, 128, 608) and (4, 128, 64, 304).
A skip connection follows (out1 = out1 + out), and the resulting tensors become the inputs of the second hourglass. The second hourglass works the same way (with out2 = out2 + out as its skip connection), and its output then passes through the fuse block, a 3x3 conv, ReLU, and a final 3x3 conv that reduces it to 2 channels, so lidar_out = out has shape (4, 2, 256, 1216).
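As a sanity check, a short sketch (dummy tensors shaped like the decoder embeddings above, with hourglass_1 and convbn as defined above) reproduces these shapes:

import torch

# Dummy shape check for the first hourglass module.
hg1 = hourglass_1(32)
local_feat = torch.rand(4, 32, 256, 1216)    # output of self.convbnrelu
em1 = torch.rand(4, 32, 128, 608)            # decoder embedding1
em2 = torch.rand(4, 64, 64, 304)             # decoder embedding2

with torch.no_grad():
    out, x, x_prime = hg1(local_feat, em1, em2)

print(out.shape)       # torch.Size([4, 32, 256, 1216])
print(x.shape)         # torch.Size([4, 64, 128, 608])
print(x_prime.shape)   # torch.Size([4, 128, 64, 304])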
Step 4: late fusion.

        # 4. Late Fusion
        lidar_to_depth, lidar_to_conf = torch.chunk(out, 2, dim=1)
        lidar_to_conf, conf = torch.chunk(self.softmax(torch.cat((lidar_to_conf, conf), 1)), 2, dim=1)
        out = conf * precise_depth + lidar_to_conf * lidar_to_depth

        return out, lidar_out, precise_depth, global_features

The previous output is split along the channel dimension into lidar_to_depth and lidar_to_conf. lidar_to_conf is concatenated with the global branch's conf, passed through a softmax, and split again, yielding normalized confidence maps. These weight the global branch's depth (precise_depth) and the local branch's depth (lidar_to_depth); the two weighted maps are summed to give the final out. The model also returns lidar_out, precise_depth and global_features. That completes the framework.
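The confidence weighting is easy to see with a tiny numeric example (made-up values, not real network outputs): the softmax over the two confidence channels gives per-pixel weights that sum to 1, so the final depth is a per-pixel convex combination of the global and local predictions.

import torch

# Toy late-fusion example on a 1x1x2x2 "image" (values are made up).
precise_depth  = torch.tensor([[[[10.0, 20.0], [30.0, 40.0]]]])   # global depth
lidar_to_depth = torch.tensor([[[[12.0, 18.0], [31.0, 45.0]]]])   # local depth
conf           = torch.tensor([[[[ 2.0,  0.0], [ 1.0, -1.0]]]])   # global confidence (logits)
lidar_to_conf  = torch.tensor([[[[ 0.0,  1.0], [ 1.0,  2.0]]]])   # local confidence (logits)

softmax = torch.nn.Softmax(dim=1)
lidar_w, global_w = torch.chunk(softmax(torch.cat((lidar_to_conf, conf), 1)), 2, dim=1)

out = global_w * precise_depth + lidar_w * lidar_to_depth
print(global_w + lidar_w)   # all ones: the two weights partition unity at every pixel
print(out)                  # per-pixel weighted average of the two depth maps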

--------------------------------------------------------------------------------
Beyond the overall framework, let's look at how training works.
1. Define the hyperparameters, then initialize the optimizer and the learning-rate policy.
2. Choose the loss function; the default is MSE.

class MSE_loss(nn.Module):
    def __init__(self):
        super(MSE_loss, self).__init__()

    def forward(self, prediction, gt, epoch=0):
        err = prediction[:,0:1] - gt
        mask = (gt > 0).detach()
        mse_loss = torch.mean((err[mask])**2)  # mean squared error over the valid pixels only
        return mse_loss

First we take the first channel of prediction and subtract gt from it to get the error err.
Then gt > 0 builds a boolean mask the same size as gt (True where ground truth exists, False elsewhere); detach() keeps it out of the autograd graph so no gradient flows through the mask.
Finally we average the squared error over the valid pixels.
A closer look at err[mask]: initializing some random tensors shows that boolean indexing flattens a high-dimensional tensor into 1-D (view can flatten as well).
However, wherever the mask is False, the corresponding values of err are dropped; a plain reshape cannot do that.
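A minimal sketch with made-up numbers shows both behaviours: boolean indexing returns a flattened 1-D tensor and silently drops the positions where the mask is False.

import torch

gt  = torch.tensor([[0.0, 2.0], [3.0, 0.0]])    # zeros mark missing ground truth
err = torch.tensor([[0.5, -1.0], [2.0, 4.0]])

mask = (gt > 0)             # tensor([[False,  True], [ True, False]])
print(err[mask])            # tensor([-1.,  2.])  -> flattened, invalid positions dropped
print(err.view(-1))         # tensor([ 0.5000, -1.0000,  2.0000,  4.0000])  -> only flattened

mse = torch.mean(err[mask] ** 2)   # loss over the valid pixels only
print(mse)                         # tensor(2.5000)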
In the end the MSE reduces all of these per-pixel errors to a single scalar loss.
Loading pretrained weights into the global branch:

    # Load pretrained state for cityscapes in GLOBAL net
    if args.pretrained and not args.resume:
        if not args.load_external_mod:
            if not args.multi:
                target_state = model.depthnet.state_dict()
            else:
                target_state = model.module.depthnet.state_dict()
            check = torch.load('erfnet_pretrained.pth')
            for name, val in check.items():
                # Exclude multi GPU prefix
                mono_name = name[7:] 
                if mono_name not in target_state:
                     continue
                try:
                    target_state[mono_name].copy_(val)
                except RuntimeError:
                    continue
            print('Successfully loaded pretrained model')
        

Without multi-GPU training the weights are loaded directly into model.depthnet; with DataParallel you go through model.module.depthnet instead.
In the single-GPU case, target_state holds the parameters of depthnet, i.e. the ERFNet parameters (compare the first and last keys).
Now inspect the parameters stored in erfnet_pretrained.pth (only the names are printed here).
name[7:] drops the first seven characters, i.e. the 'module.' prefix that DataParallel adds when saving, so mono_name is the parameter name without that prefix.
After stripping the prefix the names match those in target_state, so the if/continue branch never triggers; the try block runs, and except only fires if copy_ raises an error (for example a shape mismatch), since only an exception inside try is handed to except.
target_state[mono_name].copy_(val) copies the pretrained value val into the corresponding depthnet tensor in place, so the global branch starts from the pretrained ERFNet weights (not the other way around).
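To make the direction of the copy concrete, here is a minimal standalone sketch of the same loading pattern (a stand-in conv layer and a fake checkpoint, not the real erfnet_pretrained.pth): copy_ writes the pretrained value into the model's tensor in place.

import torch
import torch.nn as nn

model = nn.Conv2d(3, 8, 3)                       # stand-in for model.depthnet
target_state = model.state_dict()

# Pretend checkpoint saved from a DataParallel model: keys carry a 'module.' prefix.
check = {'module.weight': torch.ones_like(model.weight),
         'module.bias':   torch.zeros_like(model.bias)}

for name, val in check.items():
    mono_name = name[7:]                         # strip the 7-character 'module.' prefix
    if mono_name not in target_state:
        continue
    try:
        target_state[mono_name].copy_(val)       # in place: pretrained value -> model tensor
    except RuntimeError:                         # e.g. shape mismatch for a changed layer
        continue

print(model.weight.mean().item())                # 1.0 -> the model now holds the checkpoint weights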
Start training. Each epoch performs the following steps.

    for epoch in range(args.start_epoch, args.nepochs):
        print("\n => Start EPOCH {}".format(epoch + 1))
        print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        print(args.save_path)
        # Adjust learning rate
        if args.lr_policy is not None and args.lr_policy != 'plateau':
            scheduler.step()
            lr = optimizer.param_groups[0]['lr']
            print('lr is set to {}'.format(lr))

        # Define container objects
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        score_train = AverageMeter()
        score_train_1 = AverageMeter()
        metric_train = Metrics(max_depth=args.max_depth, disp=args.use_disp, normal=args.normal)

        # Train model for args.nepochs
        model.train()

        # compute timing
        end = time.time()

        # Load dataset
        for i, (input, gt) in tqdm(enumerate(train_loader)):

            # Time dataloader
            data_time.update(time.time() - end)

            # Put inputs on gpu if possible
            if not args.no_cuda:
                input, gt = input.cuda(), gt.cuda()
            prediction, lidar_out, precise, guide = model(input, epoch)

            loss = criterion_local(prediction, gt)
            loss_lidar = criterion_lidar(lidar_out, gt)
            loss_rgb = criterion_rgb(precise, gt)
            loss_guide = criterion_guide(guide, gt)
            loss = args.wpred*loss + args.wlid*loss_lidar + args.wrgb*loss_rgb + args.wguide*loss_guide

            losses.update(loss.item(), input.size(0))
            metric_train.calculate(prediction[:, 0:1].detach(), gt.detach())
            score_train.update(metric_train.get_metric(args.metric), metric_train.num)
            score_train_1.update(metric_train.get_metric(args.metric_1), metric_train.num)

            # Clip gradients (useful for instabilities or mistakes in the ground truth)
            if args.clip_grad_norm != 0:
                nn.utils.clip_grad_norm(model.parameters(), args.clip_grad_norm)

            # Setup backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Time training iteration
            batch_time.update(time.time() - end)
            end = time.time()

            # Print info
            if (i + 1) % args.print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Metric {score.val:.4f} ({score.avg:.4f})'.format(
                       epoch+1, i+1, len(train_loader), batch_time=batch_time,
                       loss=losses,
                       score=score_train))

First, lr_policy is not None here, but it equals 'plateau', so this if branch is skipped.
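The 'plateau' policy presumably corresponds to torch.optim.lr_scheduler.ReduceLROnPlateau (an assumption about the repo's setup); that scheduler would be stepped with the validation loss after each epoch rather than here, for example:

import torch
import torch.nn as nn

# Hypothetical 'plateau' setup, not the repo's exact code.
net = nn.Linear(1, 1)                                      # stand-in for the real model
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                       factor=0.5, patience=2)

val_loss = 1.23                                            # made-up validation loss
scheduler.step(val_loss)   # lr is reduced when the validation loss stops improving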
Next, step into the AverageMeter class:

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
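
For example (made-up loss values, batch size 4, using the class above), the meter keeps a per-sample running average:

losses = AverageMeter()

losses.update(0.8, n=4)   # batch 1: val = 0.8, sum = 3.2, count = 4, avg = 0.8
losses.update(0.4, n=4)   # batch 2: val = 0.4, sum = 4.8, count = 8, avg = 0.6

print(losses.val, losses.avg)   # 0.4 0.6  -> avg = total loss / total number of samples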

Then step into the Metrics class; its three arguments default to (85.0, False, False).

class Metrics(object):
    def __init__(self, max_depth=85.0, disp=False, normal=False):
        self.rmse, self.mae = 0, 0
        self.num = 0
        self.disp = disp
        self.max_depth = max_depth
        self.min_disp = 1.0/max_depth
        self.normal = normal

    def calculate(self, prediction, gt):
        valid_mask = (gt > 0).detach()

        self.num = valid_mask.sum().item()
        prediction = prediction[valid_mask]
        gt = gt[valid_mask]

        if self.disp:
            prediction = torch.clamp(prediction, min=self.min_disp)
            prediction = 1./prediction
            gt = 1./gt
        if self.normal:
            prediction = prediction * self.max_depth
            gt = gt * self.max_depth
        prediction = torch.clamp(prediction, min=0, max=self.max_depth)  # clamp every element of prediction to [0, max_depth] and return a new tensor

        abs_diff = (prediction - gt).abs()
        self.rmse = torch.sqrt(torch.mean(torch.pow(abs_diff, 2))).item()
        self.mae = abs_diff.mean().item()

    def get_metric(self, metric_name):
        return self.__dict__[metric_name]

Then the model is set to train mode.
Within each epoch we iterate over the dataloader, move input and gt onto the GPU, and feed the input into the model, which returns four outputs. Each of the four outputs is compared against gt to compute a loss, and the losses are combined with the weights wpred, wlid, wrgb and wguide.
Next the loss meter is updated: loss.item() and the batch size are passed to losses.update().
On the first update, with a batch size of 4: val = loss, sum = loss * 4, count = 4, avg = loss * 4 / 4 = loss. On the second update: sum = loss_1 * 4 + loss_2 * 4, count = 8, avg = (loss_1 * 4 + loss_2 * 4) / 8. In short, avg is the total loss divided by the total number of samples seen.
Next metric_train is computed. The arguments are prediction[:, 0:1].detach() and gt.detach(), both of shape torch.Size([4, 1, 256, 1216]).
This calls Metrics.calculate:
valid_mask has shape (4, 1, 256, 1216) and consists of True/False values; for this walkthrough assume they are all True.
num = 4 x 1 x 256 x 1216 = 1245184.
prediction is flattened by the mask to a 1-D tensor of length 1245184.
gt likewise becomes a 1-D tensor of length 1245184.
disp is False, so that if branch is skipped.
normal is False by default, so that branch is skipped as well.
clamp limits prediction to [0, 85], and the difference between prediction and ground truth is computed.
rmse and mae then follow their usual definitions: rmse = sqrt(mean((prediction - gt)^2)) and mae = mean(|prediction - gt|), exactly as in the code above.
Then metric_train calls get_metric, which returns the value stored under the key metric_name in the object's __dict__.

    def get_metric(self, metric_name):
        return self.__dict__[metric_name]

In the previous step we filled in metric_train; its attributes form a dictionary (its __dict__):
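A small sketch with toy tensors (using the Metrics class above) shows the bookkeeping: calculate() stores num, rmse and mae as attributes, and get_metric just reads them back out of __dict__ by name.

import torch

metric_train = Metrics(max_depth=85.0)            # disp=False, normal=False by default

prediction = torch.tensor([[[[10.0, 50.0], [30.0, 90.0]]]])
gt         = torch.tensor([[[[12.0,  0.0], [30.0, 80.0]]]])   # zero = no ground truth

metric_train.calculate(prediction, gt)
print(metric_train.num)                           # 3 valid pixels
print(metric_train.get_metric('rmse'))            # ~3.109, same as metric_train.__dict__['rmse']
print(metric_train.get_metric('mae'))             # ~2.333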
Here rmse came out as 0.408595; together with metric_train.num = 1245184 it is passed into score_train.update(), producing the new running score.
score_train_1 is built the same way; the only difference is that one tracks RMSE and the other MAE.
Then the gradients are zeroed, the loss is backpropagated, and the optimizer takes a step.
After an epoch finishes, the script prints:
the training-set RMSE, i.e. the average held by score_train;
the training-set MAE, i.e. the average held by score_train_1.
Then valid_loader (the validation data), model, criterion_lidar, criterion_rgb, criterion_local, criterion_guide and epoch are passed to the validation routine, which returns score_valid, score_valid_1 and losses_valid.
Saving the model:

        if total_score < lowest_loss:

            to_save = True
            best_epoch = epoch+1
            lowest_loss = total_score
        save_checkpoint({
            'epoch': epoch + 1,
            'best epoch': best_epoch,
            'arch': args.mod,
            'state_dict': model.state_dict(),
            'loss': lowest_loss,
            'optimizer': optimizer.state_dict()}, to_save, epoch)
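
save_checkpoint itself is a helper defined elsewhere in the repo; a minimal sketch of what such a helper typically looks like (the path layout and file names here are assumptions, not the repo's exact code):

import os
import shutil
import torch

def save_checkpoint(state, to_save, epoch, save_path='checkpoints'):
    # Persist the latest state every epoch; keep an extra copy of the best model.
    os.makedirs(save_path, exist_ok=True)
    filename = os.path.join(save_path, 'checkpoint_epoch_{}.pth.tar'.format(epoch))
    torch.save(state, filename)
    if to_save:
        shutil.copyfile(filename, os.path.join(save_path, 'model_best.pth.tar'))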

Finally, a look at the validate code; it closely mirrors the training loop, so I won't go through it line by line.

def validate(loader, model, criterion_lidar, criterion_rgb, criterion_local, criterion_guide, epoch=0):
    # batch_time = AverageMeter()
    losses = AverageMeter()
    metric = Metrics(max_depth=args.max_depth, disp=args.use_disp, normal=args.normal)
    score = AverageMeter()
    score_1 = AverageMeter()
    # Evaluate model
    model.eval()
    # Only forward pass, hence no grads needed
    with torch.no_grad():
        # end = time.time()
        for i, (input, gt) in tqdm(enumerate(loader)):
            if not args.no_cuda:
                input, gt = input.cuda(non_blocking=True), gt.cuda(non_blocking=True)
            prediction, lidar_out, precise, guide = model(input, epoch)

            loss = criterion_local(prediction, gt, epoch)
            loss_lidar = criterion_lidar(lidar_out, gt, epoch)
            loss_rgb = criterion_rgb(precise, gt, epoch)
            loss_guide = criterion_guide(guide, gt, epoch)
            loss = args.wpred*loss + args.wlid*loss_lidar + args.wrgb*loss_rgb + args.wguide*loss_guide
            losses.update(loss.item(), input.size(0))

            metric.calculate(prediction[:, 0:1], gt)
            score.update(metric.get_metric(args.metric), metric.num)
            score_1.update(metric.get_metric(args.metric_1), metric.num)
