TSPNet代码分析

本文分析论文《Realigning Confidence with Temporal Saliency Information for Point-Level Weakly-Supervised Temporal Action Localization》的官方代码（official code）。

论文解读

代码分析

先看训练过程。程序入口为 main：

if __name__ == '__main__':
    experiment = Exp()
    # Dispatch on the configured mode: 'eval' runs evaluation only,
    # every other mode runs training.
    entry_point = experiment.test if experiment.config.mode == 'eval' else experiment.train
    entry_point()

main 中首先实例化 Exp 类：

class Exp(object):
    """Experiment runner: holds the config and device and drives training.

    ``train`` runs an iteration-based loop with periodic label updates and
    periodic evaluation on the test split.
    """

    def __init__(self, exp_type='THUMOS14'):
        """Load the config for ``exp_type``, optionally seed, and pick a device.

        Args:
            exp_type: experiment name forwarded to ``self._get_config``.
        """
        self.config = self._get_config(exp_type)
        # A seed of -1 skips _setup_seed (presumably: run unseeded).
        if self.config.seed != -1:
            self._setup_seed()
        self.device = self._get_device()

    def train(self):
        """Run ``config.num_itr`` training iterations.

        Each iteration trains on one proposal batch. Every
        ``config.update_fre`` iterations the training labels are refreshed via
        ``update_label``, and every 100 iterations the model is evaluated on
        the test split via ``test_proposal``.
        """
        train_dataset, train_loader = self._get_data(subset='train')
        _, test_loader = self._get_data(subset='test')

        model = self._get_model().to(self.device)

        criterion = self._get_criterion()
        optimizer = self._get_optimizer(model)

        loader = iter(train_loader)
        for itr in tqdm(range(1, self.config.num_itr + 1), total=self.config.num_itr):
            # The iterator is re-created every `itrs_per_epoch` iterations,
            # presumably because train_one_proposal_batch draws
            # config.batch_size items from it per call -- confirm.
            # max(1, ...) prevents a ZeroDivisionError when the loader holds
            # fewer than batch_size items. The value is recomputed every
            # iteration because update_label may change the loader's length.
            itrs_per_epoch = max(1, len(train_loader) // self.config.batch_size)
            if (itr - 1) % itrs_per_epoch == 0:
                loader = iter(train_loader)
            train_one_proposal_batch(model, self.device, loader, criterion, optimizer, self.config.batch_size)

            # Periodically refresh the training labels using the current model.
            if itr % self.config.update_fre == 0:
                update_label(dataset=train_dataset, dataloader=train_loader, model=model,
                             device=self.device, up_threshold=self.config.up_threshold)

            if itr % 100 == 0:
                test_proposal(self.config, model, self.device, test_loader, itr)

可以看到，Exp 在初始化时读取配置（seed 不为 -1 时设置随机种子）并选择设备；main 再根据 config.mode 决定执行 test（'eval' 模式）还是 train。
train 中首先调用 self._get_data，即实例化数据集 PTAL_Dataset：

    def _get_data(self, subset: str):
        """Construct a PTAL_Dataset for ``subset`` using values from ``self.config``.

        NOTE(review): this excerpt stops after building ``dataset`` and shows no
        return statement, yet ``train`` unpacks ``dataset, loader =
        self._get_data(...)``. The complete method presumably also wraps the
        dataset in a DataLoader and returns both -- confirm against the repo.
        """
        dataset = PTAL_Dataset(
            data_path=self.config.data_path,
            subset=subset,
            modality=self.config.modality,
            num_classes=self.config.num_classes,
            feature_fps=self.config.feature_fps,
            soft_value=self.config.soft_value
        )
class PTAL_Dataset(Dataset):
    """Point-level temporal action localization dataset (constructor excerpt).

    Loads the class mapping, ground truth, point labels and per-video fps from
    ``data_path``, then builds the proposal tables via ``_get_proposals``.
    """

    def __init__(self,
                 data_path: str,
                 subset: str = 'test',
                 modality: str = 'both',
                 num_classes: int = 20,
                 feature_fps: int = 25,
                 soft_value: float = 0.4
                 ):
        """Load annotation files and initialize proposals.

        Args:
            data_path: directory containing gt.json, train_df_ts_in_gt.csv and
                fps.json. Its last path component is the dataset name, used as
                the key into ./data/dataset_cls_dict.json.
            subset: split name; stored as ``self.subset``.
            modality: stored as ``self.modality``.
            num_classes: stored as ``self.num_classes``.
            feature_fps: stored as ``self.feature_fps``.
            soft_value: stored as ``self.soft_value``.
        """
        self.data_path = data_path
        self.subset = subset

        self.modality = modality
        self.feature_fps = feature_fps

        # normpath drops a trailing separator, so 'data/THUMOS14/' still
        # yields 'THUMOS14' (plain split('/')[-1] would yield '').
        self.dataset = self._dataset_name(self.data_path)
        self.cls_dict = self._load_json('./data/dataset_cls_dict.json')[self.dataset]
        self.num_classes = num_classes
        self.soft_value = soft_value
        # Load label files
        self.gt = self._load_json(os.path.join(self.data_path, 'gt.json'))
        # Point labels grouped by video_id.
        # NOTE(review): the train CSV is read for every subset -- confirm intended.
        self.p_label = pd.read_csv(os.path.join(self.data_path, 'train_df_ts_in_gt.csv')).groupby('video_id')
        self.fps_dict = self._load_json(os.path.join(self.data_path, 'fps.json'))
        self.delta_dict = {}
        # Get video names
        self.vid_names = self._get_vidname()

        # Get proposals
        (self.proposals,
         self.proposals_point,
         self.proposals_center_label,
         self.proposals_multi_flag,
         self.proposals_point_id) = self._get_proposals()

    @staticmethod
    def _load_json(path):
        """Parse the JSON file at ``path`` and close the file handle afterwards."""
        with open(path, 'rb') as f:
            return json.load(f)

    @staticmethod
    def _dataset_name(data_path):
        """Return the last component of ``data_path``, ignoring a trailing separator."""
        return os.path.basename(os.path.normpath(data_path))

下面重点看 _get_proposals() 函数。它用于初始化和更新 proposals，并根据原始标注点或更新后的显著性点（saliency points）生成中心标签（center labels）：

    def _get_proposals(self, delta_point_dict=None):
        """
        get proposals and generate the center labels from the original points or the updated saliency points
        """
        history_points = []
        proposals_file = json.load(open(f'{
     self.data_path}/LAC_proposal_{
     self.dataset}_{
     self.subset}.json'))[
            'results']
        proposals = {
   }
        proposals_point = {
   }
        proposals_center_label = {
   }
        proposals_multi_flag = {
   }
        proposals_point_id = {
   }
        proposals_mask = {
   }
        t_factor = self.feature_fps / 16.0

        act, bg, multi = 0, 0, 0
        for idx, name in enumerate(self.vi
  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值