论文《Realigning Confidence with Temporal Saliency Information for Point-Level Weakly-Supervised Temporal Action Localization》的official code分析
论文解读
代码分析
先看训练流程。程序入口是 `main`,代码如下:
# Script entry point: build the experiment object, then dispatch on the configured mode.
if __name__ == '__main__':
    exp = Exp()
    if exp.config.mode == 'eval':
        # Evaluation-only run.
        exp.test()
    else:
        # Every mode other than 'eval' starts training.
        exp.train()
首先实例化 `Exp` 类:
class Exp(object):
    """Experiment driver holding the run configuration and the compute device.

    The private helpers called here (``_get_config``, ``_setup_seed``,
    ``_get_device``, ``_get_data``, ``_get_model``, ``_get_criterion``,
    ``_get_optimizer``) are not part of this excerpt.
    """

    def __init__(self, exp_type='THUMOS14'):
        """Load the configuration for ``exp_type``, optionally seed, and select the device.

        Args:
            exp_type: Identifier passed to ``_get_config``; defaults to ``'THUMOS14'``.
        """
        self.config = self._get_config(exp_type)
        # A seed of -1 skips the seeding step.
        if self.config.seed != -1:
            self._setup_seed()
        self.device = self._get_device()

    def train(self):
        """Run ``config.num_itr`` training iterations.

        Labels are updated every ``config.update_fre`` iterations and the
        model is tested every 100 iterations.
        """
        train_dataset, train_loader = self._get_data(subset='train')
        test_dataset, test_loader = self._get_data(subset='test')
        model = self._get_model().to(self.device)
        criterion = self._get_criterion()
        optimizer = self._get_optimizer(model)

        loader = iter(train_loader)
        for itr in tqdm(range(1, self.config.num_itr + 1), total=self.config.num_itr):
            # Restart the iterator every ``len(train_loader) // batch_size`` iterations.
            # NOTE(review): the divisor is 0 when len(train_loader) < batch_size,
            # which raises ZeroDivisionError -- confirm the configs rule this out.
            if (itr - 1) % (len(train_loader) // self.config.batch_size) == 0:
                loader = iter(train_loader)

            train_one_proposal_batch(model, self.device, loader, criterion, optimizer, self.config.batch_size)

            if itr % self.config.update_fre == 0:
                update_label(dataset=train_dataset, dataloader=train_loader, model=model, device=self.device, up_threshold=self.config.up_threshold)

            if itr % 100 == 0:
                test_proposal(self.config, model, self.device, test_loader, itr)
可以看到,`Exp` 在初始化时先读取配置参数(`config`),随后 `main` 根据 `config.mode` 决定执行 `test()` 还是 `train()`。
`train()` 中首先调用 `self._get_data`,其内部会实例化数据集 `PTAL_Dataset`:
def _get_data(self, subset):
    """Build the ``PTAL_Dataset`` for ``subset`` from the experiment configuration.

    Args:
        subset: Split name forwarded to the dataset (``train()`` passes
            ``'train'`` and ``'test'``).
    """
    # NOTE(review): this excerpt stops once the dataset is constructed. Callers
    # unpack a (dataset, loader) pair from this method, so the remainder of the
    # body (loader creation and the return statement) is not shown here.
    dataset = PTAL_Dataset(
        data_path=self.config.data_path,
        subset=subset,
        modality=self.config.modality,
        num_classes=self.config.num_classes,
        feature_fps=self.config.feature_fps,
        soft_value=self.config.soft_value
    )
class PTAL_Dataset(Dataset):
    """Dataset that loads label files and proposals from ``data_path``.

    Only ``__init__`` is part of this excerpt; ``_get_vidname`` and
    ``_get_proposals`` are defined elsewhere in the class.
    """

    def __init__(self,
                 data_path: str,
                 subset: str = 'test',
                 modality: str = 'both',
                 num_classes: int = 20,
                 feature_fps: int = 25,
                 soft_value: float = 0.4
                 ):
        """Store the settings, read the label files, then collect video names and proposals.

        Args:
            data_path: Dataset directory; its last ``/``-separated component is
                used as the dataset name.
            subset: Split name stored on the instance.
            modality: Modality selector stored on the instance.
            num_classes: Number of classes stored on the instance.
            feature_fps: Feature frame rate stored on the instance.
            soft_value: Soft-label value stored on the instance.
        """
        self.data_path = data_path
        self.subset = subset
        self.modality = modality
        self.feature_fps = feature_fps
        # The dataset name is the last path component and selects the class dictionary.
        self.dataset = self.data_path.split('/')[-1]
        self.cls_dict = self._load_json('./data/dataset_cls_dict.json')[self.dataset]
        self.num_classes = num_classes
        self.soft_value = soft_value

        # Load label files
        self.gt = self._load_json(os.path.join(self.data_path, 'gt.json'))
        self.p_label = pd.read_csv(os.path.join(self.data_path, 'train_df_ts_in_gt.csv')).groupby('video_id')
        self.fps_dict = self._load_json(os.path.join(self.data_path, 'fps.json'))
        self.delta_dict = {}

        # Get video names
        self.vid_names = self._get_vidname()

        # Get proposals
        self.proposals, \
            self.proposals_point, \
            self.proposals_center_label, \
            self.proposals_multi_flag, \
            self.proposals_point_id = self._get_proposals()

    @staticmethod
    def _load_json(path):
        """Parse the JSON file at ``path`` and close the file handle afterwards."""
        with open(path, 'rb') as f:
            return json.load(f)
接下来重点看 `_get_proposals()` 函数,它负责初始化 proposals,并在后续训练中更新它们:
def _get_proposals(self, delta_point_dict=None):
"""
get proposals and generate the center labels from the original points or the updated saliency points
"""
history_points = []
proposals_file = json.load(open(f'{
self.data_path}/LAC_proposal_{
self.dataset}_{
self.subset}.json'))[
'results']
proposals = {
}
proposals_point = {
}
proposals_center_label = {
}
proposals_multi_flag = {
}
proposals_point_id = {
}
proposals_mask = {
}
t_factor = self.feature_fps / 16.0
act, bg, multi = 0, 0, 0
for idx, name in enumerate(self.vi