I have recently been studying the occluded person re-identification paper "Pose-Guided Feature Disentangling for Occluded Person Re-identification Based on Transformer" (PFD) together with its source code. The backbone of the PFD network is TransReID, from "TransReID: Transformer-based Object Re-Identification", so the two are described together below.
This post covers the PFD backbone portion of the source code. The network diagram is shown below, and the code lives in model/make_pfd.py.
Picking up from the previous post, here is the overall framework of the PFD network. The overall structure of the build_skeleton_transformer class is shown below (repeated sub-networks are omitted). The hidden dimension is $D = 768$, and the number of patches is

$N = \frac{256-16+16}{16} \times \frac{128-16+16}{16} = 16 \times 8 = 128.$

The input dimension of pose_decoder_linear is 2048 because

$\frac{256}{4} \times \frac{128}{4} = 64 \times 32 = 2048,$

which matches the flattened heatmap dimension. The final classification heads output 702 dimensions because the training and test sets each contain 702 identities.
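Both numbers can be verified with a few lines of arithmetic (a quick sanity check, not part of the repo):

```python
# Patch count for a 256x128 input with 16x16 patches and stride 16
# (values taken from the architecture print below):
H, W, patch, stride = 256, 128, 16, 16
n_patches = ((H - patch + stride) // stride) * ((W - patch + stride) // stride)
print(n_patches)  # 128

# pose_decoder_linear input: HRNet heatmaps come out at 1/4 resolution
# and are flattened per keypoint
heatmap_dim = (H // 4) * (W // 4)
print(heatmap_dim)  # 2048
```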
build_skeleton_transformer(
  (decoder_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (transformerdecoderlayer): TransformerDecoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
    )
    (linear1): Linear(in_features=768, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=768, bias=True)
    (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
    (dropout3): Dropout(p=0.1, inplace=False)
  )
  (transformerdecoder): TransformerDecoder(
    (layers): ModuleList(
      (0): TransformerDecoderLayer()
      (1): TransformerDecoderLayer()
    )
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (base_vit): TransReID(
    (patch_embed): PatchEmbed_overlap(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate=none)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block()
      ...
      (11): Block()
    )
    (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (fc): Linear(in_features=768, out_features=1000, bias=True)
  )
  (b2): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): DropPath()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate=none)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  )
  (bottleneck): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_decoder): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (non_skt_decoder): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (classifier): Linear(in_features=768, out_features=702, bias=False)
  (classifier_1): Linear(in_features=768, out_features=702, bias=False)
  (classifier_2): Linear(in_features=768, out_features=702, bias=False)
  (classifier_3): Linear(in_features=768, out_features=702, bias=False)
  (classifier_4): Linear(in_features=768, out_features=702, bias=False)
  (classifier_5): Linear(in_features=768, out_features=702, bias=False)
  (classifier_6): Linear(in_features=768, out_features=702, bias=False)
  (classifier_7): Linear(in_features=768, out_features=702, bias=False)
  (classifier_8): Linear(in_features=768, out_features=702, bias=False)
  (classifier_9): Linear(in_features=768, out_features=702, bias=False)
  (classifier_10): Linear(in_features=768, out_features=702, bias=False)
  (classifier_11): Linear(in_features=768, out_features=702, bias=False)
  (classifier_12): Linear(in_features=768, out_features=702, bias=False)
  (classifier_13): Linear(in_features=768, out_features=702, bias=False)
  (classifier_14): Linear(in_features=768, out_features=702, bias=False)
  (classifier_15): Linear(in_features=768, out_features=702, bias=False)
  (classifier_16): Linear(in_features=768, out_features=702, bias=False)
  (classifier_17): Linear(in_features=768, out_features=702, bias=False)
  (bottleneck_1): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_2): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_3): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_4): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_5): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_6): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_7): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_8): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_9): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_10): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_11): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_12): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_13): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_14): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_15): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_16): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bottleneck_17): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (classifier_encoder): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder): Linear(in_features=768, out_features=702, bias=False)
  (pose_decoder_linear): Linear(in_features=2048, out_features=768, bias=True)
  (pose_avg): AdaptiveAvgPool2d(output_size=(1, 768))
  (non_parts): AdaptiveAvgPool2d(output_size=(1, 768))
  (decoder_global): AdaptiveAvgPool2d(output_size=(1, 768))
  (classifier_decoder_1): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_2): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_3): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_4): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_5): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_6): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_7): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_8): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_9): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_10): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_11): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_12): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_13): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_14): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_15): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_16): Linear(in_features=768, out_features=702, bias=False)
  (classifier_decoder_17): Linear(in_features=768, out_features=702, bias=False)
  (bottleneck_decoder_1): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
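As a quick sanity check on the part split used in the forward pass later (derived from the code, not itself part of the repo): with 128 patch tokens and 17 part queries, parts 1 through 16 each take 7 patches and part 17 takes the remaining 16:

```python
# Part split used by the encoder branch
feature_length = 128               # patch tokens, cls token excluded
num_query = 17                     # number of keypoints / part queries
patch_length = feature_length // num_query       # patches per part for parts 1..16
remainder = feature_length - 16 * patch_length   # patches left for part 17
print(patch_length, remainder)  # 7 16
```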
The actual execution flow is in the forward source code, given at the end. In the code, $f_{gp}$ is the variable sim_feat, while $f_{gb}$ should be the variable token; the code, however, applies a residual operation between it and heat_wt to produce global_out_feat. This is not mentioned in the paper, and the effect of this residual remains to be verified experimentally.
In order to integrate the pose information, we set K = M, which is exactly equal to the number of keypoints. Then, a fully connected layer is applied to heatmaps H to obtain the heatmaps H′, whose dimension is the same as the group part local feature f_gp. Next, the heatmaps H′ multiply f_gp element-wise and obtain the pose-guided feature P = [P1; P2; ...; PM].
feat = features[:, 0].unsqueeze(1) * heat_wt + features[:, 0].unsqueeze(1)
feat = feat.squeeze(1)
# f_gb feature from encoder
global_out_feat = self.bottleneck(feat) #[bs, 768]
Next come the Pose-guided Feature Aggregation (PFA) and PVM modules. The two do similar things: each adds a part prototype to its most similar feature/view, producing features that better reflect part information. The code is as follows:
# PFA
sim_decoder = PFA(sim_decoder, pose_align_wt) #[bs 17 768]
# PVM
decoder_feature, ind = PVM(sim_decoder, last_out) #[bs, num_query, 768]
import torch
import torch.nn.functional as F

def PFA(matrix, matrix1):
    '''
    @matrix  shape [bs, 17, 768]
    @matrix1 shape [bs, 17, 768]
    '''
    assert matrix.shape[0] == matrix1.shape[0], 'Wrong shape'
    assert matrix.shape[1] == matrix1.shape[1], 'Wrong skt num'
    batch_size = matrix.shape[0]  # [bs, 17, 768]
    # skt_num = matrix.shape[1]
    pose_weighted_feat = matrix * matrix1  # [bs, 17, 768]
    final_sim = F.cosine_similarity(matrix.unsqueeze(2), pose_weighted_feat.unsqueeze(1), dim=3)  # [bs, 17, 17]
    _, ind = torch.max(final_sim, dim=2)
    sim_match = []
    for i in range(batch_size):
        org_mat = matrix[i]  # [17, C]
        sim_mat = pose_weighted_feat[i]  # [17, C]
        shuffle_mat = []
        for j in range(ind.shape[1]):
            new = org_mat[j] + sim_mat[ind[i][j]]  # [C]
            new = new.unsqueeze(0)
            shuffle_mat.append(new)
        bs_mat = torch.cat(shuffle_mat, dim=0)
        sim_match.append(bs_mat)
    alignment_feat = torch.stack(sim_match, dim=0)  # [bs, 17, 768]
    return alignment_feat
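The two nested loops in PFA can be replaced by a single torch.gather; a vectorized equivalent sketched with random tensors (my own rewrite, not from the repo):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, k, c = 2, 17, 768
matrix = torch.randn(bs, k, c)   # encoder part features (sim_decoder)
matrix1 = torch.randn(bs, k, c)  # pose_align_wt

pose_weighted = matrix * matrix1
sim = F.cosine_similarity(matrix.unsqueeze(2), pose_weighted.unsqueeze(1), dim=3)  # [bs, 17, 17]
ind = sim.argmax(dim=2)          # best pose-weighted match for each prototype
# out[i, j] = matrix[i, j] + pose_weighted[i, ind[i, j]], same as PFA's loops
matched = torch.gather(pose_weighted, 1, ind.unsqueeze(-1).expand(-1, -1, c))
alignment_feat = matrix + matched  # [bs, 17, 768]
```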
def PVM(matrix, matrix1):
    '''
    @matrix  shape [bs, 17, 768]
    @matrix1 shape [bs, x, 768]
    '''
    assert matrix.shape[0] == matrix1.shape[0], 'Wrong shape'
    assert matrix.shape[2] == matrix1.shape[2], 'Wrong dimension'
    batch_size = matrix.shape[0]  # [bs, 17, 768]
    # skt_num = matrix.shape[1]
    final_sim = F.cosine_similarity(matrix.unsqueeze(2), matrix1.unsqueeze(1), dim=3)  # [bs, 17, x]
    _, ind = torch.max(final_sim, dim=2)  # ind.shape [bs, 17]
    sim_match = []
    for i in range(batch_size):
        org_mat = matrix[i]  # [17, C]
        sim_mat = matrix1[i]  # [x, C]
        shuffle_mat = []
        for j in range(ind.shape[1]):
            new = org_mat[ind[i][j]] + sim_mat[j]  # [C]
            new = new.unsqueeze(0)
            shuffle_mat.append(new)
        bs_mat = torch.cat(shuffle_mat, dim=0)
        sim_match.append(bs_mat)
    final_feature = torch.stack(sim_match, dim=0)  # [bs, 17, 768]
    return final_feature, ind
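PVM can be vectorized the same way. Note a quirk: ind indexes matrix1's part views, yet it is then used to index matrix's prototypes; this only lines up because num_query equals 17 in the repo. A sketch (my own rewrite, not repo code):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, k, c = 2, 17, 768
matrix = torch.randn(bs, k, c)   # encoder part features after PFA
matrix1 = torch.randn(bs, k, c)  # decoder part-view features (last_out)

sim = F.cosine_similarity(matrix.unsqueeze(2), matrix1.unsqueeze(1), dim=3)  # [bs, 17, 17]
ind = sim.argmax(dim=2)          # ind[i, j]: the view most similar to prototype j
# out[i, j] = matrix[i, ind[i, j]] + matrix1[i, j], matching PVM's loops
final_feature = torch.gather(matrix, 1, ind.unsqueeze(-1).expand(-1, -1, c)) + matrix1
```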
def forward(self, x, label=None, cam_label=None, view_label=None):  # ht optional
    bs, c, h, w = x.shape  # [batch, 3, 256, 128]
    # HRNet:
    heatmaps, joints = self.pose.predict(x)
    heatmaps = torch.from_numpy(heatmaps).cuda()  # [bs, 17, 64, 32]
    heatmaps = heatmaps.view(bs, heatmaps.shape[1], -1)  # [bs, 17, 2048]
    ttt = heatmaps.cpu().numpy()
    skt_ft = np.zeros((heatmaps.shape[0], heatmaps.shape[1]), dtype=np.float32)
    for i, heatmap in enumerate(ttt):  # [64]
        for j, joint in enumerate(heatmap):  # [17]
            if max(joint) < self.skeleton_threshold:
                skt_ft[i][j] = 1  # Eq 4 in paper
    skt_ft = torch.from_numpy(skt_ft).cuda()  # [64, 17]
    pose_align_wt = self.pose_decoder_linear(heatmaps)  # [bs, 17, 768] FC
    heat_wt = self.pose_avg(heatmaps)  # [bs, 1, 768]
    features = self.base_vit(x, cam_label=cam_label, view_label=view_label)  # [64, 129, 768] ViT
    # Input of decoder
    decoder_value = features * heat_wt
    decoder_value = decoder_value.permute(1, 0, 2)
    # strip
    feature_length = features.size(1) - 1  # 128
    patch_length = feature_length // self.num_query  # 128 // 17
    token = features[:, 0:1]
    x = features[:, 1:]
    sim_feat = []
    # Encoder group features
    for i in range(16):
        exec('b{}_local = x[:, patch_length*{}:patch_length*{}]'.format(i+1, i, i+1))
        exec('b{}_local_feat = self.b2(torch.cat((token, b{}_local), dim=1))'.format(i+1, i+1))
        # exec('print(b{}_local_feat.shape)'.format(i+1))
        exec('local_feat_{} = b{}_local_feat[:, 0]'.format(i+1, i+1))
        exec('sim_feat.append(local_feat_{})'.format(i+1))
    b17_local = x[:, patch_length*16:]
    b17_local_feat = self.b2(torch.cat((token, b17_local), dim=1))
    local_feat_17 = b17_local_feat[:, 0]
    sim_feat.append(local_feat_17)
    # inference list
    inf_encoder = []
    # BN
    for i in range(17):
        exec('local_feat_{}_bn = self.bottleneck_{}(local_feat_{})'.format(i+1, i+1, i+1))
        exec('inf_encoder.append(local_feat_{}_bn/17)'.format(i+1))
    feat = features[:, 0].unsqueeze(1) * heat_wt + features[:, 0].unsqueeze(1)
    feat = feat.squeeze(1)
    # f_gb feature from encoder
    global_out_feat = self.bottleneck(feat)  # [bs, 768]
    # part views
    query_embed = self.query_embed
    query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
    prototype = torch.zeros_like(query_embed)
    # part-view based decoder
    out = self.transformerdecoder(prototype, decoder_value, query_pos=query_embed)
    # part view features
    last_out = out.permute(1, 0, 2)  # [bs, num_query, 768]
    sim_decoder = torch.stack(sim_feat, dim=1)  # [bs, 17, 768]
    # PFA
    sim_decoder = PFA(sim_decoder, pose_align_wt)  # [bs, 17, 768]
    # PVM
    decoder_feature, ind = PVM(sim_decoder, last_out)  # [bs, num_query, 768]
    decoder_gb = self.decoder_global(decoder_feature).squeeze(1)  # [bs, 1, 768]
    # non skt parts
    out_non_parts = []
    # skt parts
    out_skt_parts = []
    decoder_skt_feature = []
    decoder_non_feature = []
    for i in range(bs):
        non_skt_feat_list = []
        per_skt_feat_list = []
        skt_feat = skt_ft[i]  # [17]
        # non_zero_skt = torch.nonzero(skt_feat).squeeze(1)  # [num]
        skt_part = skt_feat.cpu().numpy()
        skt_ind = np.argwhere(skt_part == 0).squeeze(1)  # [17-num] numpy type
        for j in range(decoder_feature.shape[1]):
            # version 1 use original heatmap label
            # if skt_feat[skt_ind[i][j]] == 0:
            #     non_feat = decoder_feature[i, j, :]
            #     non_skt_feat_list.append(non_feat)
            if skt_feat[ind[i][j]] == 1:  # version 2 use PVM label
                non_feat = decoder_feature[i, j, :]
                non_skt_feat_list.append(non_feat)
            else:
                skt_based_feat = decoder_feature[i, j, :]  # [768]
                per_skt_feat_list.append(skt_based_feat)
        if len(non_skt_feat_list) == 0:
            zero_feature = torch.zeros_like(decoder_gb[i])
            non_skt_feat_list.append(zero_feature)  # TODO:
        non_skt_single = torch.stack(non_skt_feat_list, dim=0).unsqueeze(0)  # [1, len(nonzero), 768]
        decoder_non_feature.append(non_skt_single)
        non_skt_single = self.non_parts(non_skt_single)  # [1, 1, 768]
        out_non_parts.append(non_skt_single)  # [[1,1,768], [1,1,768], ...] bs length
        if len(per_skt_feat_list) == 0:
            per_skt_feat_list.append(decoder_gb[i])  # TODO:
        skt_single = torch.stack(per_skt_feat_list, dim=0).unsqueeze(0)  # [1, x, 768]
        decoder_skt_feature.append(skt_single)
        skt_single = self.non_parts(skt_single)  # [1, 1, 768]
        out_skt_parts.append(skt_single)  # [[1,1,768], [1,1,768], ...] bs length
    last_non_parts = torch.cat(out_non_parts, dim=0)  # [bs, 1, 768]
    last_skt_parts = torch.cat(out_skt_parts, dim=0)  # [bs, 1, 768]
    # output high-confidence keypoint features
    decoder_out = self.bottleneck_decoder(last_skt_parts[:, 0])  # [bs, 768]
    # output non-skt-parts
    non_skt_parts = self.non_skt_decoder(last_non_parts[:, 0])
    # TODO: use last out or decoder out ??
    out_score = self.classifier_decoder(decoder_out)
    # Only high-confidence guided features are used to compute loss
    decoder_list = []
    # pad zeros for high-confidence guided features to self.num_query
    for i in decoder_skt_feature:
        if i.shape[1] < self.num_query:
            pad = torch.zeros((1, self.num_query - i.shape[1], self.in_planes)).to(i.device)
            pad_feat = torch.cat([i, pad], dim=1)  # [1, num_query, 768]
            decoder_list.append(pad_feat)
        else:
            decoder_list.append(i)
    decoder_lt = torch.cat(decoder_list, dim=0)  # [64, self.num_query, 768]
    decoder_feature = decoder_lt
    # decoder parts features
    decoder_feat = [decoder_out]
    decoder_inf = []
    for i in range(self.num_query):
        exec('b{}_deocder_local_feat = decoder_feature[:, {}]'.format(i+1, i))
        exec('decoder_feat.append(b{}_deocder_local_feat)'.format(i+1))
        exec('decoder_inf.append(b{}_deocder_local_feat/self.num_query)'.format(i+1))
    # decoder BN
    for i in range(self.num_query):
        exec('decoder_local_feat_{}_bn = self.bottleneck_decoder_{}(b{}_deocder_local_feat)'.format(i+1, i+1, i+1))
    encoder_feat = [global_out_feat] + sim_feat
    if self.training:
        # encoder parts
        cls_score = self.classifier_encoder(global_out_feat)
        encoder_score = [cls_score]
        for i in range(17):
            exec('cls_score_{} = self.classifier_{}(local_feat_{}_bn)'.format(i+1, i+1, i+1))
            exec('encoder_score.append(cls_score_{})'.format(i+1))
        decoder_score = [out_score]
        # decoder parts
        for i in range(self.num_query):
            exec('decoder_cls_score_{} = self.classifier_decoder_{}(decoder_local_feat_{}_bn)'.format(i+1, i+1, i+1))
            exec('decoder_score.append(decoder_cls_score_{})'.format(i+1))
        return encoder_score, encoder_feat, decoder_score, decoder_feat, non_skt_parts
    else:
        # Inference concat
        inf_feat = [global_out_feat] + inf_encoder + [decoder_out] + decoder_inf
        inf_features = torch.cat(inf_feat, dim=1)
        return inf_features
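The numbered classifier_i / bottleneck_i modules and the exec loops above could be expressed more idiomatically with nn.ModuleList; a hypothetical refactor sketch (the name PartHeads and its interface are my own, not from the repo):

```python
import torch
import torch.nn as nn

class PartHeads(nn.Module):
    """Hypothetical replacement for the numbered bottleneck_i / classifier_i pairs."""
    def __init__(self, num_parts=17, dim=768, num_classes=702):
        super().__init__()
        self.bottlenecks = nn.ModuleList(nn.BatchNorm1d(dim) for _ in range(num_parts))
        self.classifiers = nn.ModuleList(
            nn.Linear(dim, num_classes, bias=False) for _ in range(num_parts)
        )

    def forward(self, part_feats):
        # part_feats: list of num_parts tensors of shape [bs, dim]
        return [cls(bn(f)) for f, bn, cls in
                zip(part_feats, self.bottlenecks, self.classifiers)]

heads = PartHeads()
feats = [torch.randn(4, 768) for _ in range(17)]
scores = heads(feats)  # 17 tensors of shape [4, 702]
```

With this structure the exec-based name juggling disappears: features and scores live in ordinary Python lists and stay addressable by index.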
The code is indeed long. We can first focus on the five returned variables, encoder_score, encoder_feat, decoder_score, decoder_feat, and non_skt_parts, since the training loop iterates over exactly these five.
# output high-confidence keypoint features
decoder_out = self.bottleneck_decoder(last_skt_parts[:, 0])  # [bs, 768]
# output non-skt-parts
non_skt_parts = self.non_skt_decoder(last_non_parts[:, 0])
# TODO: use last out or decoder out ??
out_score = self.classifier_decoder(decoder_out)
# Only high-confidence guided features are used to compute loss
decoder_list = []
# pad zeros for high-confidence guided features to self.num_query
for i in decoder_skt_feature:
    if i.shape[1] < self.num_query:
        pad = torch.zeros((1, self.num_query - i.shape[1], self.in_planes)).to(i.device)
        pad_feat = torch.cat([i, pad], dim=1)  # [1, num_query, 768]
        decoder_list.append(pad_feat)
    else:
        decoder_list.append(i)
decoder_lt = torch.cat(decoder_list, dim=0)  # [64, self.num_query, 768]
decoder_feature = decoder_lt
# decoder parts features
decoder_feat = [decoder_out]
decoder_inf = []
for i in range(self.num_query):
    exec('b{}_deocder_local_feat = decoder_feature[:, {}]'.format(i+1, i))
    exec('decoder_feat.append(b{}_deocder_local_feat)'.format(i+1))
    exec('decoder_inf.append(b{}_deocder_local_feat/self.num_query)'.format(i+1))
# decoder BN
for i in range(self.num_query):
    exec('decoder_local_feat_{}_bn = self.bottleneck_decoder_{}(b{}_deocder_local_feat)'.format(i+1, i+1, i+1))
encoder_feat = [global_out_feat] + sim_feat
if self.training:
    # encoder parts
    cls_score = self.classifier_encoder(global_out_feat)
    encoder_score = [cls_score]
    for i in range(17):
        exec('cls_score_{} = self.classifier_{}(local_feat_{}_bn)'.format(i+1, i+1, i+1))
        exec('encoder_score.append(cls_score_{})'.format(i+1))
    decoder_score = [out_score]
    # decoder parts
    for i in range(self.num_query):
        exec('decoder_cls_score_{} = self.classifier_decoder_{}(decoder_local_feat_{}_bn)'.format(i+1, i+1, i+1))
        exec('decoder_score.append(decoder_cls_score_{})'.format(i+1))
    return encoder_score, encoder_feat, decoder_score, decoder_feat, non_skt_parts
(To be continued)