HR-Pro Code Debug Notes: Stage-1 Training

These notes record how this code works, for my own study.
The code framework from the paper:
[Figure: overall HR-Pro framework]
The overall framework consists of two stages of learning, with one model per stage: the Snippet-level model (S-model) and the Instance-level model (I-model).

S-model code structure

class S_Model(nn.Module):
    def __init__(self, args):
        super(S_Model, self).__init__()
        self.feature_dim = args.feature_dim
        self.num_class = args.num_class
        self.r_act = args.r_act
        self.dropout = args.dropout

        self.memory = Reliable_Memory(self.num_class, self.feature_dim)
        self.encoder = Encoder(args)
        self.classifier = nn.Sequential(
            nn.Dropout(self.dropout),
            nn.Conv1d(in_channels=self.feature_dim, out_channels=self.num_class + 1, kernel_size=1, stride=1, padding=0, bias=False)
            )
        self.sigmoid = nn.Sigmoid()
        self.bce_criterion = nn.BCELoss(reduction='none')
        self.lambdas = args.lambdas

The main components are self.memory, self.encoder, and self.classifier.
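Before looking at each component, here is a minimal sketch of how I understand these three parts fitting together in the snippet-level forward pass (the function below is my own simplification with assumed tensor layouts, not the repo's actual forward method):

def s_model_forward_sketch(model, x):
    # x: I3D snippet features, assumed shape (batch, feature_dim, T)
    embeded_feature = model.encoder(x)        # Reliability-Aware Blocks + Conv1d embedding
    cas = model.classifier(embeded_feature)   # (batch, num_class + 1, T); the extra channel is background
    cas = model.sigmoid(cas)                  # per-snippet class probabilities (class activation sequence)
    return embeded_feature, cas

In other words, the encoder refines the snippet features, the classifier turns them into a class activation sequence, and the memory stores reliable per-class prototypes used by the training losses.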
The structure of self.memory is as follows:

class Reliable_Memory(nn.Module):
    def __init__(self, num_class, feat_dim):
        super(Reliable_Memory, self).__init__()
        self.num_class = num_class
        self.feat_dim = feat_dim
        self.proto_momentum = 0.001 
        self.proto_num = 1
        self.proto_vectors = torch.nn.Parameter(torch.zeros([self.num_class, self.proto_num, self.feat_dim]), requires_grad=False)

Taking THUMOS14 (TH14) as an example,
self.memory initializes a parameter tensor of shape (20, 1, 2048) with requires_grad=False, so it is not updated by gradients but by a momentum rule (proto_momentum = 0.001).
20 is the number of classes in THUMOS14 and 2048 is the dimension of the I3D features.
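Since proto_vectors is created with requires_grad=False, the prototypes have to be refreshed outside of back-propagation. A minimal sketch of such an exponential-moving-average update (a hypothetical helper written for illustration, not the repo's exact update method) would be:

import torch

@torch.no_grad()
def update_prototypes_sketch(memory, class_idx, reliable_feat):
    # reliable_feat: (proto_num, feat_dim) features of reliable snippets for one class;
    # blend the old prototype with the new features using the momentum coefficient
    memory.proto_vectors[class_idx] = (
        (1 - memory.proto_momentum) * memory.proto_vectors[class_idx]
        + memory.proto_momentum * reliable_feat
    )

With proto_momentum = 0.001 the prototypes change slowly, which keeps them stable references for the snippet-level losses.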
The structure of self.encoder is as follows:

class Encoder(nn.Module):
    def __init__(self, args):
        super(Encoder, self).__init__()
        self.dataset = args.dataset
        self.feature_dim = args.feature_dim

        RAB_args = args.RAB_args
        self.RAB = nn.ModuleList([
            Reliabilty_Aware_Block(
                input_dim=self.feature_dim,
                dropout=RAB_args['drop_out'],
                num_heads=RAB_args['num_heads'],
                dim_feedforward=RAB_args['dim_feedforward'])
            for i in range(RAB_args['layer_num'])
        ])

        self.feature_embedding = nn.Sequential(
            nn.Conv1d(in_channels=self.feature_dim, out_channels=self.feature_dim, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

As you can see, it stacks several Reliabilty_Aware_Block modules followed by a 1D convolutional feature embedding.
The Reliabilty_Aware_Block is structured as follows:

class Reliabilty_Aware_Block(nn.Module):
    def __init__(self, input_dim, dropout, num_heads=8, dim_feedforward=128, pos_embed=False):
        super(Reliabilty_Aware_Block, self).__init__()
        self.conv_query = nn.Conv1d(input_dim, input_dim, kernel_size=1, stride=1, padding=0)
        self.conv_key = nn.Conv1d(input_dim, input_dim, kernel_size=1, stride=1, padding=0)
        self.conv_value = nn.Conv1d(input_dim, input_dim, kernel_size=1, stride=1, padding=0)

        self.self_atten = nn.MultiheadAttention(input_dim, num_heads=num_heads, dropout=0.1)
        self.linear1 = nn.Linear(input_dim, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, input_dim)
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

As you can see, this is essentially a standard Transformer encoder layer (multi-head self-attention plus a feed-forward network, each with LayerNorm and dropout), except that the query/key/value projections are 1x1 1D convolutions.
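Only the constructor is shown above; a hedged sketch of how a forward pass could chain these layers (my reconstruction from the modules defined above, not the repo's exact code) is:

import torch.nn.functional as F

def rab_forward_sketch(block, features):
    # features: (B, C, T) snippet features
    q = block.conv_query(features).permute(2, 0, 1)   # (T, B, C) for nn.MultiheadAttention
    k = block.conv_key(features).permute(2, 0, 1)
    v = block.conv_value(features).permute(2, 0, 1)

    # self-attention with residual connection and LayerNorm
    attn_out, _ = block.self_atten(q, k, v)
    x = features.permute(2, 0, 1)
    x = block.norm1(x + block.dropout1(attn_out))

    # position-wise feed-forward with residual connection and LayerNorm
    ff = block.linear2(block.dropout(F.relu(block.linear1(x))))
    x = block.norm2(x + block.dropout2(ff))
    return x.permute(1, 2, 0)                         # back to (B, C, T)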

Stage-1 training code

def main(args):
    # >> Initialize the task
    save_config(args, os.path.join(args.output_path_s1, "config.json"))
    utils.set_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    args.device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu")
    
    # --------------------------------------------------Snippet-level Optimization-------------------------------------------------------#
    if args.stage == 1:
        model = S_Model(args)
        model = model.to(args.device)
        train_loader = data.DataLoader(dataset(args, phase="train", sample="random", stage=args.stage), 
                                       batch_size=1, shuffle=True, num_workers=args.num_workers)
        test_loader = data.DataLoader(dataset(args, phase="test", sample="random", stage=args.stage),
                                    batch_size=1, shuffle=False, num_workers=args.num_workers)

Let's first look at how the dataset class is written.

class dataset(Dataset):
    def __init__(self, args, phase="train", sample="random", stage=1):
        self.args = args
        self.phase = phase
        self.sample = sample
        self.stage = stage
        self.num_segments = args.num_segments
        self.class_name_lst = args.class_name_lst
        self.class_idx_dict = {cls: idx for idx, cls in enumerate(self.class_name_lst)}
        self.num_class = args.num_class
        self.t_factor = args.frames_per_sec / args.segment_frames_num
        self.data_path = args.data_path
        self.feature_dir = os.path.join(self.data_path, 'features', self.phase)
        self._prepare_data()
    
    def _prepare_data(self):
        # >> video list
        self.data_list = [item.strip() for item in list(open(os.path.join(self.data_path, "split_{}.txt".format(self.phase))))]
        print("number of {} videos:{}".format(self.phase, len(self.data_list)))
        with open(os.path.join(self.data_path, "gt_full.json")) as f:
            self.gt_dict = json.load(f)["database"]

        # >> video label
        self.vid_labels = {}
        for item_name in self.data_list:
            item_anns_list = self.gt_dict[item_name]["annotations"]
            item_label = np.zeros(self.num_class)
            for ann in item_anns_list:
                ann_label = ann["label"]
                item_label[self.class_idx_dict[ann_label]] = 1.0
            self.vid_labels[item_name] = item_label

        # >> point label
        self.point_anno = pd.read_csv(os.path.join(self.data_path, 'point_labels', 'point_gaussian.csv'))

        if self.stage == 2:
            with open(os.path.join(self.args.output_path_s1, 'proposals.json'.format(self.phase)),'r') as f:
                self.proposals_json = json.load(f)
            self.load_proposals()

        # >> ambilist
        if self.args.dataset == "THUMOS14":
            ambilist = './dataset/THUMOS14/Ambiguous_test.txt'
            ambilist = list(open(ambilist, "r"))
            self.ambilist = [a.strip("\n").split(" ") for a in ambilist]

First, it loads the video list and the ground-truth file.
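Judging from how gt_dict is accessed (each video entry has an "annotations" list whose items carry a "label"), gt_full.json follows the usual ActivityNet-style layout. A hypothetical entry, with field values invented purely for illustration, would look roughly like:

# hypothetical illustration of the expected gt_full.json layout (values made up)
gt_dict = {
    "<video_name>": {
        "duration": 100.0,
        "annotations": [
            {"label": "BaseballPitch", "segment": [12.3, 15.7]},
            {"label": "BaseballPitch", "segment": [40.1, 43.9]},
        ],
    },
    # one entry per video listed in split_{phase}.txt
}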

The resulting self.vid_labels dictionary (one multi-hot vector per video) looks like this:
{
   'video_validation_0000051': array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0.]), 'video_validation_0000052': array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0.]), 'video_validation_0000053': array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0.]), 'video_validation_0000054': array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0.]), 'video_validation_0000055': array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
........

This yields the video-level labels.
Next, the point-level labels are loaded from point_gaussian.csv.
For the THUMOS14 dataset, self.ambilist (the ambiguous test annotations) is also loaded.
Let's continue with the training code.

        if args.mode == 'train':
            logger = Logger(args.log_path_s1)
            log_filepath = os.path.join(args.log_path_s1, '{}.score'.format(args.dataset))
            initial_log(log_filepath, args)

            model.memory.init(args, model, train_loader)

The key part here is model.memory.init.

    def init(self, args, net, train_loader):
        print('Memory initialization in progress...')
        with torch.no_grad():
            net.eval()
            pfeat_total = {}
            temp_loader = data.DataLoader(train_loader.dataset, batch_size=1
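The quoted snippet ends here. Based on the Reliable_Memory design above (per-class prototypes of shape (num_class, proto_num, feat_dim)), a plausible sketch of what the rest of the initialization does, i.e. averaging encoder features of point-labeled snippets into the class prototypes, could look like the following (hypothetical code with assumed loader outputs, not the repo's exact implementation):

import torch

@torch.no_grad()
def memory_init_sketch(memory, net, temp_loader, device):
    # hypothetical assumption: each batch yields (features, point_label, ...) where
    # point_label is (1, T, num_class) with 1s at the annotated snippets
    pfeat_total = {c: [] for c in range(memory.num_class)}
    for sample in temp_loader:
        features, point_label = sample[0].to(device), sample[1].to(device)
        embeded_feature = net.encoder(features)              # (1, feat_dim, T), assumed layout
        for c in range(memory.num_class):
            idx = point_label[0, :, c].nonzero(as_tuple=True)[0]
            if len(idx) > 0:
                pfeat_total[c].append(embeded_feature[0, :, idx].mean(dim=-1))
    # average all collected point-labeled features into one prototype per class
    for c in range(memory.num_class):
        if len(pfeat_total[c]) > 0:
            memory.proto_vectors[c] = torch.stack(pfeat_total[c]).mean(dim=0, keepdim=True)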