【行为检测】ICCV 2019:Temporal Recurrent Networks for Online Action Detection

这是一篇Online Action Detection的文章,整体框架是lstm的,因为online之前的做法都是只利用历史信息,这篇主要新意在于通过预测未来的信息来帮助分类。Temporal Recurrent Network (TRN)。

由于我刚接触action detection,对于怎么确定事件开头和结束位置最好奇,这篇中这里的思路非常简单,就是做一个多分类,以THUMOS数据集为例,帧率26fps,每段视频6帧,所有的决策都以视频段为元素(0.25s),数据集本身20类,分类器分21类,用一类来表示background非动作。原文:

整体架构图

可以看到,图中左侧就是一个lstm,特征提取选择了经典方法,THUMOS使用了VGG-16 和 two stream (TS) CNN。重点在于右侧的TRN cell,右侧的可以横过来看,输入是大lstm中的隐状态h(文中把大的lstm称作Encoder),以h为输入再经过小的lstm,将输出连接起来构成future信息。

再解释一下就是,endcoder中得到了时间t的信息,那以t的信息为输入,再经过序列lstm,每个输出就可以看作是对未来t+1...t+ld的预测,这些预测再经过一个FC层和 t 时刻的结合起来,作用于encoder的下一时序。

从Loss的角度来说,两部分loss,一部分是Encoder输出和真实类别的loss,另一部分是Decoder输出和真实类别的loss,也就是强制encoder学习到预测未来的信息。

附encoder decoder部分代码,文中参数encoder 序列长64,decoder序列长8:

self.enc_drop = nn.Dropout(self.dropout)
self.enc_cell = nn.LSTMCell(self.fusion_size, self.hidden_size)
self.dec_drop = nn.Dropout(self.dropout)
self.dec_cell = nn.LSTMCell(self.hidden_size, self.hidden_size)

self.classifier = nn.Linear(self.hidden_size, self.num_classes)

def encoder(self, camera_input, sensor_input, future_input, enc_hx, enc_cx):
    fusion_input = self.feature_extractor(camera_input, sensor_input)
    fusion_input = torch.cat((fusion_input, future_input), 1)
    enc_hx, enc_cx = \
           self.enc_cell(self.enc_drop(fusion_input), (enc_hx, enc_cx))
    enc_score = self.classifier(self.enc_drop(enc_hx))
    return enc_hx, enc_cx, enc_score

def decoder(self, fusion_input, dec_hx, dec_cx):
    dec_hx, dec_cx = \
           self.dec_cell(self.dec_drop(fusion_input), (dec_hx, dec_cx))
    dec_score = self.classifier(self.dec_drop(dec_hx))
    return dec_hx, dec_cx, dec_score

代码:https://github.com/xumingze0308/TRN.pytorch

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值