学习完论文我们回归代码,看一下每一部分是怎么实现的
(4)步态识别论文研读——用于步态识别的分层时空表示学习-CSDN博客
主干网络定义了类——class HSTL
从forward 函数开始看
输入inputs
ipts, labs, _, _, seqL = inputs
输入数据维度n s h w
sils = ipts[0].unsqueeze(1) 增加通道维度
n, _, s, h, w = sils.size()
outs = self.arme1(sils)
self.arme1 = nn.Sequential( BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)), nn.LeakyReLU(inplace=True) )
输入数据经过了 arme1,这是一个三维卷积。由于 out=(in-k+2p)/s+1,代入参数(k=3, s=1, p=1)可知输出的 s, h, w 维度不变,只是通道数变为 in_c[0]=32,所以
outs 输出维度变为 n,32,s,h,w
seqL 在非训练(测试/推理)模式下会被置为 None(见 forward 中的 seqL = None if not self.training else seqL)
astp1 = self.astp1(outs, seqL)
self.astp1 = ASTP(split_param=[64], m=1, in_channels=in_c[0], out_channels=in_c[-1])
数据又经过了一个 self.astp1 ,进入ASTP类,看一下这个类的forward 函数,
x = self.SP1(x, seqL=seqL, options={"dim": 2})[0]
经过了一个时间维度上的池化操作,这个池化就是一个 torch.max()(options={"dim": 2} 表示在帧维度上取最大值)
class ASTP(nn.Module):
    """Adaptive Spatio-Temporal Pooling.

    Max-pools features over the temporal axis via ``PackSequenceWrapper(torch.max)``,
    optionally projects channels with a 1x1 conv, then applies one GeM
    horizontal-pyramid-pooling head to each of the ``m`` horizontal body-part
    strips and concatenates the results along the part axis.
    """

    def __init__(self, split_param, m, in_channels, out_channels, flag=True):
        super(ASTP, self).__init__()
        # Row counts of each horizontal strip, consumed by Tensor.split().
        self.split_param = split_param
        self.m = m
        # One single-bin GeM pooling head per body-part strip.
        heads = [GeMHPP(bin_num=[1]) for _ in range(m)]
        self.hpp = nn.ModuleList(heads)
        self.flag = flag
        if self.flag:
            # 1x1 conv projecting in_channels -> out_channels.
            self.proj = BasicConv2d(in_channels, out_channels, 1, 1, 0)
        self.SP1 = PackSequenceWrapper(torch.max)

    def forward(self, x, seqL):
        # x = self.SP1(x, seqL=seqL, options={"dim": 2})[0]
        x = self.SP1(x,2)[0]
        if self.flag:
            x = self.proj(x)
        # Split the height axis into the configured strips and pool each one
        # with its own GeM head, then stack parts along the last dimension.
        strips = x.split(self.split_param, 2)
        pooled = [head(strip) for head, strip in zip(self.hpp, strips)]
        return torch.cat(pooled, -1)
self.SP1 = PackSequenceWrapper(torch.max)
class HSTL(BaseModel):
    """
    Hierarchical Spatio-Temporal Feature Learning for Gait Recognition
    """

    def __init__(self, *args, **kargs):
        super(HSTL, self).__init__(*args, **kargs)

    def build_network(self, model_cfg):
        """Build the hierarchical ARME / ASTP / FTA pipeline.

        Args:
            model_cfg: dict providing 'channels' (per-stage channel widths,
                indexed as in_c[0..2]; in_c[-1] is the embedding width) and
                'class_num' (number of training identities for the softmax head).
        """
        in_c = model_cfg['channels']
        class_num = model_cfg['class_num']
        # For CASIA-B dataset.
        # Stage 1: plain 3x3x3 conv over the whole body; with k=3, s=1, p=1
        # the (s, h, w) extents are preserved, channels go 1 -> in_c[0].
        self.arme1 = nn.Sequential(
            BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3),
                        stride=(1, 1, 1), padding=(1, 1, 1)),
            nn.LeakyReLU(inplace=True)
        )
        # Whole-body pooling branch: single 64-row strip (m=1).
        self.astp1 = ASTP(split_param=[64], m=1, in_channels=in_c[0], out_channels=in_c[-1])
        # Stage 2: two part-aware convs over a 2-part split (40/24 rows).
        self.arme2 = nn.Sequential(
            ARME_Conv(in_c[0], in_c[0], split_param=[40, 24], m=2, kernel_size=(
                3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            ARME_Conv(in_c[0], in_c[1], split_param=[40, 24], m=2, kernel_size=(
                3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        self.astp2 = ASTP(split_param=[40, 24], m=2, in_channels=in_c[1], out_channels=in_c[-1])
        # Frame-level temporal aggregation over the same 2-part split,
        # pooled by its own ASTP branch.
        self.fta = FTA_Block(split_param=[40, 24], m=2, in_channels=in_c[1])
        self.astp2_fta = ASTP(split_param=[40, 24], m=2, in_channels=in_c[1], out_channels=in_c[-1])
        # Stage 3: finer 4-part split (8/32/16/8 rows).
        self.arme3 = nn.Sequential(
            ARME_Conv(in_c[1], in_c[2], split_param=[8, 32, 16, 8], m=4, kernel_size=(
                3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            ARME_Conv(in_c[2], in_c[2], split_param=[8, 32, 16, 8], m=4, kernel_size=(
                3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        # flag=False: skip the 1x1 channel projection at this level.
        self.astp3 = ASTP(split_param=[8, 32, 16, 8], m=4, in_channels=in_c[2], out_channels=in_c[-1], flag=False)
        # self.astp4 = ASTP(split_param=[1,1,1,1,1,1,1,1,
        #                                1,1,1,1,1,1,1,1,
        #                                1,1,1,1,1,1,1,1,
        #                                1,1,1,1,1,1,1,1,
        #                                1,1,1,1,1,1,1,1,
        #                                1,1,1,1,1,1,1,1,
        #                                1,1,1,1,1,1,1,1,
        #                                1,1,1,1,1,1,1,1], m=64, in_channels=in_c[2], out_channels=in_c[-1])
        self.HPP = GeMHPP()
        # separable fully connected layer (SeFC)
        self.Head0 = SeparateFCs(73, in_c[-1], in_c[-1])
        # batchnorm layer (BN)
        self.Bn = nn.BatchNorm1d(in_c[-1])
        # separable fully connected layer (SeFC)
        self.Head1 = SeparateFCs(73, in_c[-1], class_num)
        # Temporal Pooling (TP)
        self.TP = PackSequenceWrapper(torch.max)

    def forward(self, inputs):
        """Run all hierarchy levels and assemble the training/inference dict.

        Args:
            inputs: tuple (ipts, labs, _, _, seqL) as produced by the data
                loader; ipts[0] holds silhouettes of shape [n, s, h, w]
                (batch, frames, height, width).

        Returns:
            dict with 'training_feat' (triplet embeddings + softmax logits),
            'visual_summary' (input silhouettes), and 'inference_feat'
            (the BN-ed embeddings).
        """
        ipts, labs, _, _, seqL = inputs
        # seqL is only used while training; it is forced to None at test time.
        seqL = None if not self.training else seqL
        if not self.training and len(labs) != 1:
            raise ValueError(
                'The input size of each GPU must be 1 in testing mode, but got {}!'.format(len(labs)))
        # Add a channel axis: [n, s, h, w] -> [n, 1, s, h, w].
        sils = ipts[0].unsqueeze(1)
        del ipts
        n, _, s, h, w = sils.size()
        # Repeat very short clips along the time axis up to >= 3 frames —
        # presumably so the temporal convs/pooling see enough frames; TODO confirm.
        if s < 3:
            repeat = 3 if s == 1 else 2
            sils = sils.repeat(1, 1, repeat, 1, 1)
        # Hierarchy: each ARME stage refines features; each ASTP branch pools
        # that stage's output into part-wise embeddings.
        outs = self.arme1(sils)
        astp1 = self.astp1(outs, seqL)
        outs = self.arme2(outs)
        astp2 = self.astp2(outs, seqL)
        outs = self.fta(outs)
        astp2_fta = self.astp2_fta(outs, seqL)
        outs = self.arme3(outs)
        astp3 = self.astp3(outs, seqL)
        # Final level: temporal max-pooling then GeM pyramid pooling.
        astp4 = self.TP(outs, seqL=seqL, options={"dim": 2})[0]  # [n, c, h, w]
        astp4 = self.HPP(astp4)
        # astp4 = self.astp4(outs, seqL)
        # Concatenate every level along the part axis (p = 73 total parts,
        # matching the SeparateFCs heads above).
        outs = torch.cat([astp1,astp2, astp2_fta, astp3, astp4], dim=-1)  # [n, c, p]
        outs = outs.permute(2, 0, 1).contiguous()  # [p, n, c]
        gait = self.Head0(outs)  # [p, n, c]
        gait = gait.permute(1, 2, 0).contiguous()  # [n, c, p]
        bnft = self.Bn(gait)  # [n, c, p]
        logi = self.Head1(bnft.permute(2, 0, 1).contiguous())  # [p, n, c]
        gait = gait.permute(0, 2, 1).contiguous()  # [n, p, c]
        bnft = bnft.permute(0, 2, 1).contiguous()  # [n, p, c]
        logi = logi.permute(1, 0, 2).contiguous()  # [n, p, c]
        # print(logi.size())
        # Re-read sizes: s may have grown if the clip was repeated above.
        n, _, s, h, w = sils.size()
        retval = {
            'training_feat': {
                'triplet': {'embeddings': bnft, 'labels': labs},
                'softmax': {'logits': logi, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n * s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': bnft
            }
        }
        return retval