PFD_net and TransReID Notes — Part 3: The PFD Network

I have recently been studying the occluded person re-identification paper "Pose-Guided Feature Disentangling for Occluded Person Re-identification Based on Transformer" (PFD) and its source code. The backbone of the PFD network is the TransReID backbone from "TransReID: Transformer-based Object Re-Identification", so the two are discussed together below.
This post walks through the PFD backbone portion of the source code. The network diagram is shown below; the code lives in model/make_pfd.py.
[Figure: overall architecture of the PFD network]


Picking up from the previous post, here is the overall framework of the PFD network. The full structure of the build_skeleton_transformer class is printed below (repeated sub-modules are omitted). The hidden dimension is $D = 768$, and the number of patch tokens is $N = \frac{256-16+16}{16} \times \frac{128-16+16}{16} = 16 \times 8 = 128$. The input dimension of pose_decoder_linear is 2048 because $\frac{256}{4} \times \frac{128}{4} = 64 \times 32 = 2048$, which matches the flattened heatmap dimension. The output dimension of the final classification heads is 702 because the training and test sets each contain 702 identities. (A small dimension-check sketch follows the module printout below.)

build_skeleton_transformer(
  (decoder_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  
  (transformerdecoderlayer): TransformerDecoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
    )
    (linear1): Linear(in_features=768, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=768, bias=True)
    (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
    (dropout3): Dropout(p=0.1, inplace=False)
  )
  (transformerdecoder): TransformerDecoder(
    (layers): ModuleList(
      (0): TransformerDecoderLayer()
      (1): TransformerDecoderLayer()
    )
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  
  (base_vit): TransReID(
    (patch_embed): PatchEmbed_overlap(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate=none)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block()
      ...
      (11): Block()
    )
    (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (fc): Linear(in_features=768, out_features=1000, bias=True)
  )

  (b2): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): DropPath()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate=none)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  )
  
  (bottleneck): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_decoder): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (non_skt_decoder): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  
  (classifier): Linear(in_features=768, out_features=702, bias=False)
  (classifier_1): Linear(in_features=768, out_features=702, bias=False)
  (classifier_2): Linear(in_features=768, out_features=702, bias=False)
  (classifier_3): Linear(in_features=768, out_features=702, bias=False)
  (classifier_4): Linear(in_features=768, out_features=702, bias=False)
  (classifier_5): Linear(in_features=768, out_features=702, bias=False)
  (classifier_6): Linear(in_features=768, out_features=702, bias=False)
  (classifier_7): Linear(in_features=768, out_features=702, bias=False)
  (classifier_8): Linear(in_features=768, out_features=702, bias=False)
  (classifier_9): Linear(in_features=768, out_features=702, bias=False)
  (classifier_10): Linear(in_features=768, out_features=702, bias=False)
  (classifier_11): Linear(in_features=768, out_features=702, bias=False)
  (classifier_12): Linear(in_features=768, out_features=702, bias=False)
  (classifier_13): Linear(in_features=768, out_features=702, bias=False)
  (classifier_14): Linear(in_features=768, out_features=702, bias=False)
  (classifier_15): Linear(in_features=768, out_features=702, bias=False)
  (classifier_16): Linear(in_features=768, out_features=702, bias=False)
  (classifier_17): Linear(in_features=768, out_features=702, bias=False)
  
  (bottleneck_1): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_2): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_3): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_4): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_5): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_6): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_7): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_8): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_9): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_10): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_11): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_12): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_13): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_14): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_15): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_16): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_17): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  
  (classifier_encoder): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder): Linear(in_features=768, out_features=702, bias=False)
  (pose_decoder_linear): Linear(in_features=2048, out_features=768, bias=True)
  (pose_avg): AdaptiveAvgPool2d(output_size=(1, 768))
  (non_parts): AdaptiveAvgPool2d(output_size=(1, 768))
  (decoder_global): AdaptiveAvgPool2d(output_size=(1, 768))
  
  (classifier_decoder_1): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_2): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_3): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_4): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_5): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_6): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_7): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_8): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_9): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_10): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_11): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_12): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_13): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_14): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_15): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_16): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_17): Linear(in_features=768, out_features=702, bias=False)
  (bottleneck_decoder_1): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
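
As a quick sanity check of the dimensions above, here is a minimal sketch (not part of the repo) that recomputes the patch count and the flattened heatmap size, assuming a 256×128 input, a 16×16 patch embedding with stride 16, and a 64×32 HRNet heatmap:

# Illustrative dimension check only; not code from the PFD repository
img_h, img_w = 256, 128                 # input image size
patch, stride = 16, 16                  # PatchEmbed_overlap kernel / stride in the printout
num_patches = ((img_h - patch) // stride + 1) * ((img_w - patch) // stride + 1)
print(num_patches)                      # 16 * 8 = 128 patch tokens, plus 1 cls token -> 129

heat_h, heat_w = 64, 32                 # heatmap resolution (input size / 4)
print(heat_h * heat_w)                  # 2048, the in_features of pose_decoder_linear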


The actual execution flow has to be read from the forward source code, which is included at the end. In the code, $f_{gp}$ corresponds to the variable sim_feat, while $f_{gb}$ should correspond to the variable token; the code, however, combines it with heat_wt in a residual-style operation, producing global_out_feat. This is not mentioned in the paper, and whether the residual helps would need to be verified experimentally.

In order to integrate the pose information, we set K = M, which is exactly equal to the number of keypoints. Then, a fully connected layer is applied to heatmaps H to obtain the heatmaps H', whose dimension is the same as the group part local feature f_gp. Next, the heatmaps H' multiply f_gp element-wisely and obtain the pose-guided feature P = [P_1; P_2; ...; P_M].

feat = features[:, 0].unsqueeze(1) * heat_wt + features[:, 0].unsqueeze(1)
feat = feat.squeeze(1)
# f_gb feature from encoder
global_out_feat = self.bottleneck(feat) #[bs, 768]
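
To make the difference concrete, here is a minimal sketch, assuming x stands in for the cls-token feature features[:, 0].unsqueeze(1) of shape [bs, 1, 768] and heat_wt for the pooled heatmap embedding self.pose_avg(heatmaps) of the same shape; this only illustrates the two forms and is not code from the repo:

import torch

bs, dim = 4, 768
x = torch.randn(bs, 1, dim)        # stand-in for the cls-token feature
heat_wt = torch.randn(bs, 1, dim)  # stand-in for the pooled heatmap embedding

weighted = x * heat_wt             # element-wise re-weighting only, the form described in the quoted paragraph
residual = x * heat_wt + x         # what the forward code does; algebraically equal to x * (heat_wt + 1)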

Next come the Pose-guided Feature Aggregation (PFA) and PVM modules. The two do similar things: each matches a part prototype with its most similar counterpart (the pose-weighted feature for PFA, the part view for PVM) and adds the two together, producing features that better reflect the corresponding body part. The code is as follows:

#  PFA 
sim_decoder = PFA(sim_decoder, pose_align_wt) #[bs 17 768]
#  PVM
decoder_feature, ind = PVM(sim_decoder, last_out) #[bs, num_query, 768]
def PFA(matrix, matrix1):
    '''
    @matrix shape [bs, 17, 768]
    @matrix1 shape [bs, 17, 768]

    '''
    assert matrix.shape[0] == matrix1.shape[0], 'Wrong shape'
    assert matrix.shape[1] == matrix1.shape[1], 'Wrong skt num'

    batch_size = matrix.shape[0] #[bs, 17, 768]

    # skt_num = matrix.shape[1]

    pose_weighted_feat = matrix * matrix1   #[bs, 17, 768]

    final_sim = F.cosine_similarity(matrix.unsqueeze(2), pose_weighted_feat.unsqueeze(1), dim=3) #[bs, 17, 17]

    _, ind = torch.max(final_sim, dim=2)

    sim_match = []
    for i in range(batch_size):
        org_mat = matrix[i] #[17, C]
        sim_mat = pose_weighted_feat[i] #[17, C]
        shuffle_mat = []

        for j in range(ind.shape[1]):
            new = org_mat[j] + sim_mat[ind[i][j]]  #[C]
            new = new.unsqueeze(0)
            shuffle_mat.append(new)

        bs_mat = torch.cat(shuffle_mat, dim=0)

        sim_match.append(bs_mat)
    
    alignment_feat = torch.stack(sim_match, dim=0)   #[bs, 17, 768]

    return alignment_feat


def PVM(matrix, matrix1):
    '''
    @matrix shape [bs, 17, 768]
    @matrix1 shape [bs, x, 768] 
    '''

    assert matrix.shape[0] == matrix1.shape[0], 'Wrong shape'
    assert matrix.shape[2] == matrix1.shape[2], 'Wrong dimension'

    batch_size = matrix.shape[0] #[bs, 17, 768]
    # skt_num = matrix.shape[1]
    final_sim = F.cosine_similarity(matrix.unsqueeze(2), matrix1.unsqueeze(1), dim=3) #[bs, 17, x] 

    _, ind = torch.max(final_sim, dim=2)    # ind.shape [bs, x]

    
    sim_match = []
    for i in range(batch_size):
        org_mat = matrix[i] #[17, C]
        sim_mat = matrix1[i] #[x, C]
        shuffle_mat = []

        for j in range(ind.shape[1]):
            new = org_mat[ind[i][j]] + sim_mat[j]  #[C]
            new = new.unsqueeze(0)
            shuffle_mat.append(new)

        bs_mat = torch.cat(shuffle_mat, dim=0)

        sim_match.append(bs_mat)
    
    final_feature = torch.stack(sim_match, dim=0)   #[bs, x, 768]

    return final_feature, ind
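
As a quick usage illustration (not from the repo), PFA and PVM can be exercised with random tensors to check the expected shapes; bs, the number of keypoints, and num_query are assumed to be 4, 17 and 17 here, and the PFA/PVM definitions above (with their torch / F imports) are assumed to be in scope:

import torch

bs, num_kpt, num_query, dim = 4, 17, 17, 768
sim_decoder = torch.randn(bs, num_kpt, dim)     # encoder part features
pose_align_wt = torch.randn(bs, num_kpt, dim)   # heatmaps after pose_decoder_linear
last_out = torch.randn(bs, num_query, dim)      # part view features from the decoder

aligned = PFA(sim_decoder, pose_align_wt)       # [bs, 17, 768]
matched, ind = PVM(aligned, last_out)           # [bs, num_query, 768], ind: [bs, num_query]
print(aligned.shape, matched.shape, ind.shape)

The full forward of build_skeleton_transformer is reproduced below: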
    def forward(self, x, label=None, cam_label=None, view_label=None):  # ht (heatmap) optional

        bs, c, h, w = x.shape # [batch, 3, 256, 128]

        # HRNet:
        heatmaps, joints = self.pose.predict(x)
        heatmaps = torch.from_numpy(heatmaps).cuda()    #[bs, 17, 64, 32]

        heatmaps = heatmaps.view(bs, heatmaps.shape[1], -1) # [bs, 17, 2048]

        ttt = heatmaps.cpu().numpy()
        skt_ft = np.zeros((heatmaps.shape[0], heatmaps.shape[1]), dtype=np.float32)

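        # skt_ft[i][j] = 1 marks a low-confidence keypoint (max heatmap response below skeleton_threshold), 0 a high-confidence one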
        for i, heatmap in enumerate(ttt):  #[64]
            for j, joint in enumerate(heatmap): #[17]

                if max(joint) < self.skeleton_threshold:
                    skt_ft[i][j] = 1    # Eq 4 in paper

        skt_ft = torch.from_numpy(skt_ft).cuda()    #[64, 17]

        pose_align_wt = self.pose_decoder_linear(heatmaps)  #[bs, 17, 768] FC

        heat_wt = self.pose_avg(heatmaps) #[bs, 1, 768]

        features = self.base_vit(x, cam_label=cam_label, view_label=view_label) # [64, 129, 768] ViT

        # Input of decoder 
        decoder_value = features * heat_wt
        decoder_value = decoder_value.permute(1,0,2)

        # strip 
        feature_length = features.size(1) - 1   #128
        patch_length = feature_length // self.num_query  #128 // 17
        token = features[:, 0:1]
        x = features[:, 1:]
    
        sim_feat = []
        # Encoder group features
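        # split the 128 patch tokens into 17 groups (groups 1-16 take patch_length tokens each, group 17 takes the rest);
        # each group is prepended with the cls token and passed through b2, and its cls output becomes that group's local feature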
        for i in range(16):
            exec('b{}_local = x[:, patch_length*{}:patch_length*{}]'.format(i+1, i, i+1))

            exec('b{}_local_feat = self.b2(torch.cat((token, b{}_local), dim=1))'.format(i+1, i+1))
            # exec('print(b{}_local_feat.shape)'.format(i+1))
            exec('local_feat_{} = b{}_local_feat[:, 0]'.format(i+1, i+1))

            exec('sim_feat.append(local_feat_{})'.format(i+1))

        b17_local = x[:, patch_length*16:]
        b17_local_feat = self.b2(torch.cat((token, b17_local), dim=1))
        local_feat_17 = b17_local_feat[:, 0]
        sim_feat.append(local_feat_17)

        # inference list
        inf_encoder = []
        # BN
        for i in range(17):
            exec('local_feat_{}_bn = self.bottleneck_{}(local_feat_{})'.format(i+1, i+1, i+1))
            exec('inf_encoder.append(local_feat_{}_bn/17)'.format(i+1))

        feat = features[:, 0].unsqueeze(1) * heat_wt + features[:, 0].unsqueeze(1)

        feat = feat.squeeze(1)

        # f_gb feature from encoder
        global_out_feat = self.bottleneck(feat) #[bs, 768]

        # part views
        query_embed = self.query_embed

        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)

        prototype = torch.zeros_like(query_embed)
        
        # part-view based decoder
        out = self.transformerdecoder(prototype, decoder_value, query_pos=query_embed)

        # part view features
        last_out = out.permute(1,0,2)   # [bs, num_query, 768]

        sim_decoder = torch.stack(sim_feat, dim=1)  #[bs, 17, 768]

        #  PFA 
        sim_decoder = PFA(sim_decoder, pose_align_wt) #[bs 17 768]

        #  PVM
        decoder_feature, ind = PVM(sim_decoder, last_out) #[bs num_query 768]

        decoder_gb = self.decoder_global(decoder_feature).squeeze(1)   #[bs, 768]

        # non skt parts 
        out_non_parts = []
        # skt parts 
        out_skt_parts = []

        decoder_skt_feature = []

        decoder_non_feature = []

        for i in range(bs):
            non_skt_feat_list = []
            per_skt_feat_list = []

            skt_feat = skt_ft[i]   #[17]
            # non_zero_skt = torch.nonzero(skt_feat).squeeze(1) #[num]

            skt_part = skt_feat.cpu().numpy()
            skt_ind = np.argwhere(skt_part==0).squeeze(1) #[17-num] numpy type

            for j in range(decoder_feature.shape[1]):

                # version 1 use original heatmap label
                # if skt_feat[skt_ind[i][j]] == 0: 
                #     non_feat = decoder_feature[i, j, :]
                #     non_skt_feat_list.append(non_feat)

                if skt_feat[ind[i][j]] == 1: # version 2 use PVM label
                    non_feat = decoder_feature[i, j, :]
                    non_skt_feat_list.append(non_feat)

                else:
                    skt_based_feat = decoder_feature[i, j, :] #[768]
                    per_skt_feat_list.append(skt_based_feat)
    

            if len(non_skt_feat_list) == 0:
                zero_feature = torch.zeros_like(decoder_gb[i])
                non_skt_feat_list.append(zero_feature)         #TODO:
            non_skt_single = torch.stack(non_skt_feat_list, dim=0).unsqueeze(0)  #[1, len(nonzero), 768]
            
            decoder_non_feature.append(non_skt_single)
            non_skt_single = self.non_parts(non_skt_single) #[1, 1, 768]
            out_non_parts.append(non_skt_single) # [[1,1,768], [1,1,768], ....] bs length

            if len(per_skt_feat_list) == 0:
                per_skt_feat_list.append(decoder_gb[i])         #TODO:
            skt_single = torch.stack(per_skt_feat_list, dim=0).unsqueeze(0)     #[1, x, 768]

            decoder_skt_feature.append(skt_single)
            skt_single = self.non_parts(skt_single) #[1, 1, 768]
            out_skt_parts.append(skt_single)    # [[1,1,768], [1,1,768], ....] bs length


        last_non_parts = torch.cat(out_non_parts, dim=0)    #[bs, 1, 768]

        last_skt_parts = torch.cat(out_skt_parts, dim=0)    #[bs, 1, 768]

        # output high-confidence keypoint features
        decoder_out = self.bottleneck_decoder(last_skt_parts[:, 0]) #[bs, 768]

        # output non-skt-parts
        non_skt_parts = self.non_skt_decoder(last_non_parts[:, 0]) 

        # TODO:use last out or decoder out ?? 
        out_score = self.classifier_decoder(decoder_out)

        # Only high-confidence guided features are used to compute loss
        decoder_list = []

        # pad zeros for high-confidence guided features to self.num_query
        for i in decoder_skt_feature:
            if i.shape[1] < self.num_query:
                pad = torch.zeros((1,self.num_query-i.shape[1], self.in_planes)).to(i.device)
                pad_feat = torch.cat([i, pad], dim=1)  #[1, num_query, 768]
                decoder_list.append(pad_feat)
            else:
                decoder_list.append(i)


        decoder_lt = torch.cat(decoder_list, dim=0) # [64, self.num_query, 768]

        decoder_feature = decoder_lt


        # decoder parts features
        decoder_feat = [decoder_out]
        decoder_inf = []
        for i in range(self.num_query):
            exec('b{}_deocder_local_feat = decoder_feature[:, {}]'.format(i+1, i))
            exec('decoder_feat.append(b{}_deocder_local_feat)'.format(i+1))
            exec('decoder_inf.append(b{}_deocder_local_feat/self.num_query)'.format(i+1))

        # decoder BN
        for i in range(self.num_query):
            exec('decoder_local_feat_{}_bn = self.bottleneck_decoder_{}(b{}_deocder_local_feat)'.format(i+1, i+1, i+1))

        encoder_feat = [global_out_feat] + sim_feat 

        if self.training:
            # encoder parts
            cls_score = self.classifier_encoder(global_out_feat)

            encoder_score = [cls_score]

            for i in range(17):
                
                exec('cls_score_{} = self.classifier_{}(local_feat_{}_bn)'.format(i+1, i+1, i+1))
                exec('encoder_score.append(cls_score_{})'.format(i+1))

            decoder_score = [out_score]

            # decoder parts
            for i in range(self.num_query):

                exec('decoder_cls_score_{} = self.classifier_decoder_{}(decoder_local_feat_{}_bn)'.format(i+1, i+1, i+1))
                exec('decoder_score.append(decoder_cls_score_{})'.format(i+1))

            return encoder_score, encoder_feat ,decoder_score, decoder_feat, non_skt_parts

        else:
            # Inference concat
            inf_feat = [global_out_feat] + inf_encoder + [decoder_out] + decoder_inf
            inf_features = torch.cat(inf_feat, dim=1)

            return inf_features

The code is indeed long. We can first focus on the five returned values: encoder_score, encoder_feat, decoder_score, decoder_feat, and non_skt_parts; the training loop computes its losses from these five outputs.
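
As a rough sketch of how these outputs might be consumed during training (the actual loss is built elsewhere in the repo and may differ; model, imgs, labels and id_criterion are assumed names for illustration):

import torch

# Illustrative only: average an ID classification loss over the per-part scores
id_criterion = torch.nn.CrossEntropyLoss()

encoder_score, encoder_feat, decoder_score, decoder_feat, non_skt_parts = model(imgs, label=labels)

id_loss = sum(id_criterion(score, labels) for score in encoder_score) / len(encoder_score)
id_loss += sum(id_criterion(score, labels) for score in decoder_score) / len(decoder_score)
# triplet / pose-guided push losses on encoder_feat, decoder_feat and non_skt_parts would be added on top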

(To be continued)
