DIN网络结构部分代码讲解

承接上篇文章,本篇文章对照论文的网络结构图在源码中找到对应部分并作简单介绍,由于源码较多,不能一一介绍,感兴趣的可以去自行下载。

 

上篇文章地址:(2条消息) DIN:用于群体行为识别的动态时空推理网络_Mr___WQ的博客-CSDN博客https://blog.csdn.net/Mr___WQ/article/details/129131100

源码链接:GitHub - JacobYuan7/DIN-Group-Activity-Recognition-Benchmark: A new codebase for Group Activity Recognition. It contains codes for ICCV 2021 paper: Spatio-Temporal Dynamic Inference Network for Group Activity Recognition and some other methods.https://github.com/JacobYuan7/DIN-Group-Activity-Recognition-Benchmark 话不多说,先上网络结构图。

 本文提出的DIN的基本框架如上图所示:DIN 的输入为一小段视频,将其输入选定的主干网络以提取视觉特征。对于主干网络,作者主要在ResNet-18和VGG-16上进行实验,然后应用RoIAlign提取与边界框对齐的人物特征,将其嵌入到D维空间中。作者首先构建一个初始化的时空图,该时空图的连接为人物特征的时空邻居(空间维度按照人的坐标排序)。在这个初始化的时空图上,作者在定义的交互域内进行动态关系和动态游走预测,得到中心特征各异的交互图(总共T×N个交互图),然后中心特征可以在各自的交互图上进行特征更新。最后,DIN通过全局的时空池化得到视频的特征表示。

项目目录:

backbone文件夹中有backbone.py,其中定义了MyInception_v3、MyVGG16、MyVGG19、MyRes18、MyRes50和MyAlex等结构,可根据配置灵活选用。

base_model.py:里面定义了两个model,分别对应于Volleyball和Collective数据集。这两个model对应于网络结构图中的Spation-Temporal Feature Extraction。此处仅以Volleyball数据集的代码为例,代码如下(里面的内容已经作了注释,不做过多赘述,有理解错误或说错的的地方还请谅解,下同):

class Basenet_volleyball(nn.Module):
    """
    main module of base model for the volleyball
    """
    def __init__(self, cfg):
        super(Basenet_volleyball, self).__init__()
        #cfg的定义在Config.py中,如Config('volleyball')
        self.cfg=cfg
        #NFB:每个box的特征数
        NFB=self.cfg.num_features_boxes
        #D:Embedding的维度,源码中为512,Lite-DIN为128
        D=self.cfg.emb_features
        # K用作ROI_Align剪裁
        K=self.cfg.crop_size[0]
        
        #backbone的选择
        if cfg.backbone=='inv3':
            self.backbone=MyInception_v3(transform_input=False,pretrained=True)
        elif cfg.backbone=='vgg16':
            self.backbone=MyVGG16(pretrained=True)
        elif cfg.backbone=='vgg19':
            self.backbone=MyVGG19(pretrained=True)
        elif cfg.backbone == 'res18':
            self.backbone = MyRes18(pretrained = True)
        else:
            assert False
        
        #通过roi_align进行crop
        self.roi_align=RoIAlign(*self.cfg.crop_size)
        
        #Embedding的全连接
        self.fc_emb = nn.Linear(K*K*D,NFB)
        self.dropout_emb = nn.Dropout(p=self.cfg.train_dropout_prob)
        
        #全连接后的actions和activities
        self.fc_actions=nn.Linear(NFB,self.cfg.num_actions)
        self.fc_activities=nn.Linear(NFB,self.cfg.num_activities)
        
        #权重初始化
        for m in self.modules():
            if isinstance(m,nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.zeros_(m.bias)

    #
    #保存模型
    def savemodel(self,filepath):
        state = {
            'backbone_state_dict': self.backbone.state_dict(),
            'fc_emb_state_dict':self.fc_emb.state_dict(),
            'fc_actions_state_dict':self.fc_actions.state_dict(),
            'fc_activities_state_dict':self.fc_activities.state_dict()
        }
        
        torch.save(state, filepath)
        print('model saved to:',filepath)

    #加载模型
    def loadmodel(self,filepath):
        state = torch.load(filepath)
        self.backbone.load_state_dict(state['backbone_state_dict'])
        self.fc_emb.load_state_dict(state['fc_emb_state_dict'])
        self.fc_actions.load_state_dict(state['fc_actions_state_dict'])
        self.fc_activities.load_state_dict(state['fc_activities_state_dict'])
        print('Load model states from: ',filepath)

    #建立网络结构,前向传播,_init_中定义了网络结构里的各个小块,在此建立关系
    def forward(self,batch_data):
        images_in, boxes_in = batch_data
        
        # read config parameters
        #B:batch_size
        B=images_in.shape[0]
        #T:frames,帧数
        T=images_in.shape[1]
        #H:height, W:width,图片高宽
        H, W=self.cfg.image_size
        #需要输出得到的尺寸
        OH, OW=self.cfg.out_size
        #每一帧中框的个数
        N=self.cfg.num_boxes
        #NFB同上
        NFB=self.cfg.num_features_boxes
        
        # Reshape the input data,作shape变换
        images_in_flat=torch.reshape(images_in,(B*T,3,H,W))  #B*T, 3, H, W
        boxes_in_flat=torch.reshape(boxes_in,(B*T*N,4))  #B*T*N, 4

        boxes_idx=[i * torch.ones(N, dtype=torch.int)   for i in range(B*T) ]
        boxes_idx=torch.stack(boxes_idx).to(device=boxes_in.device)  # B*T, N
        boxes_idx_flat=torch.reshape(boxes_idx,(B*T*N,))  #B*T*N,
        
        
        # Use backbone to extract features of images_in
        # Pre-precess first
        images_in_flat=prep_images(images_in_flat)
        
        outputs=self.backbone(images_in_flat)
        
        
        # Build multiscale features
        features_multiscale=[]
        for features in outputs:
            if features.shape[2:4]!=torch.Size([OH,OW]):
                features=F.interpolate(features,size=(OH,OW),mode='bilinear',align_corners=True)
            features_multiscale.append(features)
        
        features_multiscale=torch.cat(features_multiscale,dim=1)  #B*T, D, OH, OW
        
        
        
        # ActNet
        boxes_in_flat.requires_grad=False
        boxes_idx_flat.requires_grad=False
#         features_multiscale.requires_grad=False
        
    
        # RoI Align
        boxes_features=self.roi_align(features_multiscale,
                                            boxes_in_flat,
                                            boxes_idx_flat)  #B*T*N, D, K, K,
        
        
        boxes_features=boxes_features.reshape(B*T*N,-1) # B*T*N, D*K*K
        
            
        # Embedding to hidden state
        boxes_features=self.fc_emb(boxes_features)  # B*T*N, NFB
        boxes_features=F.relu(boxes_features)
        boxes_features=self.dropout_emb(boxes_features)
       
    
        boxes_states=boxes_features.reshape(B,T,N,NFB)
        
        # Predict actions
        boxes_states_flat=boxes_states.reshape(-1,NFB)  #B*T*N, NFB

        actions_scores=self.fc_actions(boxes_states_flat)  #B*T*N, actn_num
        
        
        # Predict activities
        boxes_states_pooled,_=torch.max(boxes_states,dim=2)  #B, T, NFB
        boxes_states_pooled_flat=boxes_states_pooled.reshape(-1,NFB)  #B*T, NFB
        
        activities_scores=self.fc_activities(boxes_states_pooled_flat)  #B*T, acty_num
        
        if T!=1:
            actions_scores=actions_scores.reshape(B,T,N,-1).mean(dim=1).reshape(B*N,-1)
            activities_scores=activities_scores.reshape(B,T,-1).mean(dim=1)
            
        return actions_scores, activities_scores

上面的代码对应网络结构左侧,由于内存限制,整个网络分成两个stage进行训练,以上是第一个stage的结构,训练时用的是train_volleyball_stage1.py,训练完成后保存模型文件以便第二阶段训练时直接加载即可。

第二部分是推理部分,文件嵌套顺序为

infer_model.py->/infer_module/dynamic_infer_module.py

与stage1相比多了inference部分, infer_model.py中定义了第二部分的网络结构,下面的代码是与stage1不同的地方,作者将网络的第二阶段封装在了一个class中,定义在dynamic_infer_module.py。

        if not self.cfg.hierarchical_inference:
            # self.DPI = Dynamic_Person_Inference(
            #     in_dim = in_dim,
            #     person_mat_shape = (10, 12),
            #     stride = cfg.stride,
            #     kernel_size = cfg.ST_kernel_size,
            #     dynamic_sampling=cfg.dynamic_sampling,
            #     sampling_ratio = cfg.sampling_ratio, # [1,2,4]
            #     group = cfg.group,
            #     scale_factor = cfg.scale_factor,
            #     beta_factor = cfg.beta_factor,
            #     parallel_inference = cfg.parallel_inference,
            #     cfg = cfg)
            self.DPI = Multi_Dynamic_Inference(
                in_dim = in_dim,
                person_mat_shape = (10, 12),
                stride = cfg.stride,
                kernel_size = cfg.ST_kernel_size,
                dynamic_sampling=cfg.dynamic_sampling,
                sampling_ratio = cfg.sampling_ratio, # [1,2,4]
                group = cfg.group,
                scale_factor = cfg.scale_factor,
                beta_factor = cfg.beta_factor,
                parallel_inference = cfg.parallel_inference,
                num_DIM = cfg.num_DIM,
                cfg = cfg)
            print_log(cfg.log_path, 'Hierarchical Inference : ' + str(cfg.hierarchical_inference))
        else:
            self.DPI = Hierarchical_Dynamic_Inference(
                in_dim = in_dim,
                person_mat_shape=(10, 12),
                stride=cfg.stride,
                kernel_size=cfg.ST_kernel_size,
                dynamic_sampling=cfg.dynamic_sampling,
                sampling_ratio=cfg.sampling_ratio,  # [1,2,4]
                group=cfg.group,
                scale_factor=cfg.scale_factor,
                beta_factor=cfg.beta_factor,
                parallel_inference=cfg.parallel_inference,
                cfg = cfg,)
            print(cfg.log_path, 'Hierarchical Inference : ' + str(cfg.hierarchical_inference))
        self.dpi_nl = nn.LayerNorm([T, N, in_dim])
        self.dropout_global = nn.Dropout(p=self.cfg.train_dropout_prob)


        # Lite Dynamic inference
        if self.cfg.lite_dim:
            self.point_conv = nn.Conv2d(NFB, in_dim, kernel_size = 1, stride = 1)
            self.point_ln = nn.LayerNorm([T, N, in_dim])
            self.fc_activities = nn.Linear(in_dim, self.cfg.num_activities)
        else:
            self.fc_activities=nn.Linear(NFG, self.cfg.num_activities)

        for m in self.modules():
            if isinstance(m,nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

值得说明的是Lite Dynamic inference是将特征的embed维度从512 ->128,减小了计算量。

进一步探索dynamic_infer_module.py,由于代码过长,这里只贴上DR、DW相关代码,有兴趣的可以参看完整代码,具体细节参照论文。

    def parallel_infer(self, person_features, ratio):
        assert self.dynamic_sampling and self.scale_factor

        # Dynamic affinity infer
        scale = self.scale_conv[str(ratio)](person_features).permute(0, 2, 3, 1) # shape [B, T, N, k2]
        # DR
        scale = F.softmax(scale, dim = -1)
        pos = self._get_plain_pos(ratio, person_features) # [B, T, N, 2*k2]

        pad_ft = self.zero_padding[str(ratio)](person_features).permute(0, 2, 3, 1) # [B, H, W, NFB]
        pad_ft = pad_ft.view(pad_ft.shape[0], -1, pad_ft.shape[-1])
        ft_pos = self._get_ft(pad_ft, pos.long(), ratio)

        ft_infer_scale =  torch.sum(ft_pos * scale.unsqueeze(-1), dim = 3)


        # Dynamic walk infer
        offset = self.p_conv[str(ratio)](person_features).permute(0, 2, 3, 1)  # shape [B, T, N,  2*k2]
        pos = self._get_pos(offset, ratio)  # [B, T, N, 2*k2]

        # Original
        lt = pos.data.floor()
        rb = lt + 1

        # Calclate bilinear coefficient
        # corner point position. lt shape # [B, T, N, 2*k2]
        k2 = self.kernel_size[0]*self.kernel_size[1]
        lt = torch.cat((torch.clamp(lt[..., :k2], 0, self.T + 2 * ratio - 1),
                        torch.clamp(lt[..., k2:], 0, self.N + 2 * ratio - 1)), dim=-1)
        rb = torch.cat((torch.clamp(rb[..., :k2], 0, self.T + 2 * ratio - 1),
                        torch.clamp(rb[..., k2:], 0, self.N + 2 * ratio - 1)), dim=-1)
        lb = torch.cat((rb[..., :k2], lt[..., k2:]), dim=-1)
        rt = torch.cat((lt[..., :k2], rb[..., k2:]), dim=-1)

        # coefficient for cornor point pixel.  coe shape [B, T, N, k2]
        pos = torch.cat((torch.clamp(pos[..., :k2], 0, self.T + 2 * ratio),
                         torch.clamp(pos[..., k2:], 0, self.N + 2 * ratio)), dim=-1)
        coe_lt = (1 - torch.abs(pos[..., :k2] - lt[..., :k2])) * (1 - torch.abs(pos[..., k2:] - lt[..., k2:]))
        coe_rb = (1 - torch.abs(pos[..., :k2] - rb[..., :k2])) * (1 - torch.abs(pos[..., k2:] - rb[..., k2:]))
        coe_lb = (1 - torch.abs(pos[..., :k2] - lb[..., :k2])) * (1 - torch.abs(pos[..., k2:] - lb[..., k2:]))
        coe_rt = (1 - torch.abs(pos[..., :k2] - rt[..., :k2])) * (1 - torch.abs(pos[..., k2:] - rt[..., k2:]))


        # corner point feature.  ft shape [B, T, N, k2, NFB]
        # pad_ft = self.zero_padding[ratio](person_features).permute(0, 2, 3, 1)
        # pad_ft = pad_ft.view(pad_ft.shape[0], -1, pad_ft.shape[-1])
        ft_lt = self._get_ft(pad_ft, lt.long(), ratio)
        ft_rb = self._get_ft(pad_ft, rb.long(), ratio)
        ft_lb = self._get_ft(pad_ft, lb.long(), ratio)
        ft_rt = self._get_ft(pad_ft, rt.long(), ratio)
        ft_infer_walk = ft_lt * coe_lt.unsqueeze(-1) + \
                   ft_rb * coe_rb.unsqueeze(-1) + \
                   ft_lb * coe_lb.unsqueeze(-1) + \
                   ft_rt * coe_rt.unsqueeze(-1)

        ft_infer_walk = torch.mean(ft_infer_walk, dim=3)

        return ft_infer_scale + ft_infer_walk

最近在用实验室机器跑这篇论文的有关代码,backbone选用vgg16,30个epoch时,stage1的best如下图所示:

由于机器内存限制,stage2的代码无法运行,无法进行复现,后续条件允许的情况下,将进行stage2的训练。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值