Gaze360代码解读

代码链接 :http://gaze360.csail.mit.edu

论文链接:https://paperswithcode.com/paper/gaze360-physically-unconstrained-gaze

Gaze360模型

注视是自然的连续信号。凝视注视和过渡产生一系列凝视方向。为了利用这一点,论文提出了一个基于视频的凝视跟踪模型使用双向长期短期记忆胶囊(LSTM),它提供了一种对序列进行建模的方法,其中一个元素的输出取决于过去和将来的输入。在该论文中,作者利用7个帧的序列来预测中心帧的视线。注意,仅包括单个中央框架的其他序列长度也是可能的。

上图说明了Gaze360模型的体系结构。卷积神经网络(主干)分别处理每个帧中的头部作物,该神经网络产生具有256维的高级特征。这些特征被馈送到具有两层的双向LSTM,这些LSTM消化前向和后向向量中的序列。最后,将这些向量连接起来并通过一个完全连接的层,以产生两个输出:凝视预测和误差分位数估计。

Gaze360模型代码拆分

GazeLSTM

class GazeLSTM(nn.Module):
    def __init__(self):
        super(GazeLSTM, self).__init__()
        self.img_feature_dim = 256  # the dimension of the CNN feature to represent each frame

        self.base_model = resnet18(pretrained=True)

        self.base_model.fc2 = nn.Linear(1000, self.img_feature_dim)

        self.lstm = nn.LSTM(self.img_feature_dim, self.img_feature_dim,bidirectional=True,num_layers=2,batch_first=True)

        # The linear layer that maps the LSTM with the 3 outputs
        self.last_layer = nn.Linear(2*self.img_feature_dim, 3)


    def forward(self, input):

        base_out = self.base_model(input.view((-1, 3) + input.size()[-2:]))

        base_out = base_out.view(input.size(0),7,self.img_feature_dim)

        lstm_out, _ = self.lstm(base_out)
        lstm_out = lstm_out[:,3,:]
        output = self.last_layer(lstm_out).view(-1,3)


        angular_output = output[:,:2]
        angular_output[:,0:1] = math.pi*nn.Tanh()(angular_output[:,0:1])
        angular_output[:,1:2] = (math.pi/2)*nn.Tanh()(angular_output[:,1:2])

        var = math.pi*nn.Sigmoid()(output[:,2:3])
        var = var.view(-1,1).expand(var.size(0),2)

 首先是model.py中的GazeLSTM部分,首先在初始化函数中定义了图片的特征维度为256,主干网络是resnet18,将输入数据的shape通过view函重塑为(-1,3,input.size()[-1],input.size()[-2])的形状,并输入到resnet1得到base_out,改变其shape为(-1,7,256),传入至双向LSTM中保存t时刻的输出lstm_out[:,3,:],将该输出传入至全连接层并将输出的shape改为(-1,3)。

PinBallLoss

使用神经网络做回归任务,我们使用MSE、MAE作为损失函数,最终得到的输出y通常会被近似为y的期望值,但有些情况下目标值y的空间可能会比较大,只预测一个期望值并不能帮助我们做进一步的决策。

这里介绍一个特殊的损失函数——分位数损失,利用分位数损失我们不需要对数据进行任何先验的处理,就可以轻松做到预测输出y的某一分位数水平值,例如5%分位数或95%分位数,利用这个输出很自然就完成预测输出范围的回归模型。

分位数损失函数的表达式如下图:

 代码中以一个简明的表达方式来表达上式:

class PinBallLoss(nn.Module):
    def __init__(self):
        super(PinBallLoss, self).__init__()
        self.q1 = 0.1
        self.q9 = 1-self.q1

    def forward(self, output_o,target_o,var_o):
        q_10 = target_o-(output_o-var_o)
        q_90 = target_o-(output_o+var_o)

        loss_10 = torch.max(self.q1*q_10, (self.q1-1)*q_10)
        loss_90 = torch.max(self.q9*q_90, (self.q9-1)*q_90)


        loss_10 = torch.mean(loss_10)
        loss_90 = torch.mean(loss_90)

        return loss_10+loss_90

 Gaze360训练函数理解(run.py)

def main():
    global args, best_error

    model_v = GazeLSTM()
    model = torch.nn.DataParallel(model_v).cuda()
    model.cuda()


    cudnn.benchmark = True

    image_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


    train_loader = torch.utils.data.DataLoader(
        ImagerLoader(source_path,train_file,transforms.Compose([
            transforms.RandomResizedCrop(size=224,scale=(0.8,1)),transforms.ToTensor(),image_normalize,
        ])),
        batch_size=batch_size, shuffle=True,
        num_workers=workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        ImagerLoader(source_path,val_file,transforms.Compose([
            transforms.Resize((224,224)),transforms.ToTensor(),image_normalize,
        ])),
        batch_size=batch_size, shuffle=True,
        num_workers=workers, pin_memory=True)



    criterion = PinBallLoss().cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr)

    if test==True:

        test_loader = torch.utils.data.DataLoader(
            ImagerLoader(source_path,test_file,transforms.Compose([
                transforms.Resize((224,224)),transforms.ToTensor(),image_normalize,
            ])),
            batch_size=batch_size, shuffle=True,
            num_workers=workers, pin_memory=True)
        checkpoint = torch.load(checkpoint_test)
        model.load_state_dict(checkpoint['state_dict'])
        angular_error = validate(test_loader, model, criterion)
        print('Angular Error is',angular_error)


    for epoch in range(0, epochs):


        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        angular_error = validate(val_loader, model, criterion)

        # remember best angular error in validation and save checkpoint
        is_best = angular_error < best_error
        best_error = min(angular_error, best_error)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_error,
        }, is_best)

 在main函数中调用预先设计的模型,加载训练数据集和验证数据集,使用PinBallLoss损失函数和Adam优化器。判断是否为测试模式,如果是测试模型还需加载测试数据集。在接下来的for循环中则是对每一个epoch进行一次训练,对验证集进行评估,记住最好的角度错误在验证和保存检查点。最后输出每个epoch的凝视估计和分位数误差估计。

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 11
    评论
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值