(2)HSTL步态识别代码解析超详细版

 学习完论文我们回归代码,看一下每一部分是怎么实现的

(4)步态识别论文研读——用于步态识别的分层时空表示学习-CSDN博客

主干网络定义了类——class HSTL

从forward 函数开始看

输入inputs

ipts, labs, _, _, seqL = inputs

输入数据维度n s h w 

sils = ipts[0].unsqueeze(1)  # 增加通道维度,数据由 n, s, h, w 变为 n, 1, s, h, w
n, _, s, h, w = sils.size()
outs = self.arme1(sils)
self.arme1 = nn.Sequential(
    BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3),
                stride=(1, 1, 1), padding=(1, 1, 1)),
    nn.LeakyReLU(inplace=True)
)

输入数据经过了 arme1,这是一个三维卷积。由于 out = (in − k + 2p)/s + 1,代入参数(k=3, p=1, s=1)可知输出的 s, h, w 维度不变,只有通道数由 1 变为 in_c[0]=32,所以

outs 输出维度变为 n,32,s,h,w

seqL 在测试(非训练)阶段为 None(见 forward 中的 seqL = None if not self.training else seqL)

astp1 = self.astp1(outs, seqL)
self.astp1 = ASTP(split_param=[64], m=1, in_channels=in_c[0], out_channels=in_c[-1])

数据又经过了一个 self.astp1 ,进入ASTP类,看一下这个类的forward 函数,

x = self.SP1(x, seqL=seqL, options={"dim": 2})[0]

经过了一个沿时间维度(dim=2)的池化操作,这个池化本质上就是一次 torch.max() 取最大值

class ASTP(nn.Module):
    """Adaptive Spatio-Temporal Pooling.

    Max-pools the input over the temporal dimension, optionally projects
    channels with a 1x1 conv, splits the feature map into horizontal
    strips, pools each strip with its own GeMHPP head, and concatenates
    the per-strip features along the last (part) dimension.
    """

    def __init__(self, split_param, m, in_channels, out_channels, flag=True):
        super(ASTP, self).__init__()
        # Strip heights used to split the feature map along dim 2.
        self.split_param = split_param
        # Number of strips; expected to equal len(split_param).
        self.m = m
        # One GeM horizontal pyramid pooling head per strip.
        self.hpp = nn.ModuleList([GeMHPP(bin_num=[1]) for _ in range(self.m)])
        # When flag is set, channels are projected to out_channels.
        self.flag = flag
        if self.flag:
            self.proj = BasicConv2d(in_channels, out_channels, 1, 1, 0)
        # Temporal pooling: max over the sequence dimension.
        self.SP1 = PackSequenceWrapper(torch.max)

    def forward(self, x, seqL):
        # NOTE(review): seqL is accepted but unused; the author replaced the
        # original call `self.SP1(x, seqL=seqL, options={"dim": 2})` with a
        # positional form that pools over dim 2 — confirm intent upstream.
        pooled = self.SP1(x, 2)[0]
        if self.flag:
            pooled = self.proj(pooled)
        strips = pooled.split(self.split_param, 2)
        parts = [self.hpp[idx](strip) for idx, strip in enumerate(strips)]
        return torch.cat(parts, -1)

self.SP1 = PackSequenceWrapper(torch.max)
class HSTL(BaseModel):
    """
        Hierarchical Spatio-Temporal Feature Learning for Gait Recognition

        Stacks 3D-conv stages (arme1..arme3) interleaved with ASTP pooling
        branches at several part granularities; all pooled part features are
        concatenated and fed through part-wise fully-connected heads.
    """

    def __init__(self, *args, **kargs):
        super(HSTL, self).__init__(*args, **kargs)

    def build_network(self, model_cfg):
        # Per-stage channel widths; in_c[-1] is the embedding size.
        in_c = model_cfg['channels']
        class_num = model_cfg['class_num']
        # For CASIA-B dataset.
        # Stage 1: lift the 1-channel silhouette to in_c[0] channels.
        # kernel 3, stride 1, padding 1 keeps (s, h, w) unchanged.
        self.arme1 = nn.Sequential(
            BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3),
                        stride=(1, 1, 1), padding=(1, 1, 1)),
            nn.LeakyReLU(inplace=True)
        )

        # Whole-body pooling branch: one strip covering the full height
        # (split_param=[64] — presumably 64-pixel-high silhouettes; confirm).
        self.astp1 = ASTP(split_param=[64], m=1, in_channels=in_c[0], out_channels=in_c[-1])

        # Stage 2: two part-aware 3D convs over a 2-part split (40 + 24 rows).
        self.arme2 = nn.Sequential(
            ARME_Conv(in_c[0], in_c[0], split_param=[40, 24], m=2, kernel_size=(
                3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            ARME_Conv(in_c[0], in_c[1], split_param=[40, 24], m=2, kernel_size=(
                3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )

        # 2-part pooling branch on the stage-2 output.
        self.astp2 = ASTP(split_param=[40, 24], m=2, in_channels=in_c[1], out_channels=in_c[-1])

        # Frame-level temporal aggregation block over the same 2-part split.
        self.fta = FTA_Block(split_param=[40, 24], m=2, in_channels=in_c[1])

        # 2-part pooling branch on the FTA output.
        self.astp2_fta = ASTP(split_param=[40, 24], m=2, in_channels=in_c[1], out_channels=in_c[-1])

        # Stage 3: two part-aware 3D convs over a finer 4-part split.
        self.arme3 = nn.Sequential(
            ARME_Conv(in_c[1], in_c[2], split_param=[8, 32, 16, 8], m=4, kernel_size=(
                3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            ARME_Conv(in_c[2], in_c[2], split_param=[8, 32, 16, 8], m=4, kernel_size=(
                3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )

        # 4-part pooling branch; flag=False skips the channel projection.
        self.astp3 = ASTP(split_param=[8, 32, 16, 8], m=4, in_channels=in_c[2], out_channels=in_c[-1], flag=False)

        # Abandoned 64-part variant, replaced by TP + GeMHPP below.
        # self.astp4 = ASTP(split_param=[1,1,1,1,1,1,1,1,
        #                                1,1,1,1,1,1,1,1,
        #                                1,1,1,1,1,1,1,1,
        #                                1,1,1,1,1,1,1,1,
        #                                1,1,1,1,1,1,1,1,
        #                                1,1,1,1,1,1,1,1,
        #                                1,1,1,1,1,1,1,1,
        #                                1,1,1,1,1,1,1,1], m=64, in_channels=in_c[2], out_channels=in_c[-1])
        self.HPP = GeMHPP()

        # separable fully connected layer (SeFC)
        # 73 parts — presumably 1 + 2 + 2 + 4 + 64 from the five branches
        # concatenated in forward (64 = GeMHPP default bins); confirm.
        self.Head0 = SeparateFCs(73, in_c[-1], in_c[-1])
        # batchnorm layer (BN)
        self.Bn = nn.BatchNorm1d(in_c[-1])
        # separable fully connected layer (SeFC)
        self.Head1 = SeparateFCs(73, in_c[-1], class_num)
        # Temporal Pooling (TP)
        self.TP = PackSequenceWrapper(torch.max)

    def forward(self, inputs):
        # inputs: (inputs-list, labels, _, _, sequence lengths).
        ipts, labs, _, _, seqL = inputs
        # Sequence lengths are only meaningful during training.
        seqL = None if not self.training else seqL
        if not self.training and len(labs) != 1:
            raise ValueError(
                'The input size of each GPU must be 1 in testing mode, but got {}!'.format(len(labs)))
        # Add a channel axis: [n, s, h, w] -> [n, 1, s, h, w].
        sils = ipts[0].unsqueeze(1)
        del ipts
        n, _, s, h, w = sils.size()
        # Pad very short clips so the 3-frame temporal kernels get >= 3 frames.
        if s < 3:
            repeat = 3 if s == 1 else 2
            sils = sils.repeat(1, 1, repeat, 1, 1)
        # Hierarchical feature extraction; each astp* is a pooled part branch.
        outs = self.arme1(sils)
        astp1 = self.astp1(outs, seqL)
        outs = self.arme2(outs)
        astp2 = self.astp2(outs, seqL)
        outs = self.fta(outs)
        astp2_fta = self.astp2_fta(outs, seqL)
        outs = self.arme3(outs)
        astp3 = self.astp3(outs, seqL)
        # Finest branch: temporal max-pool then horizontal pyramid pooling.
        astp4 = self.TP(outs, seqL=seqL, options={"dim": 2})[0]  # [n, c, h, w]
        astp4 = self.HPP(astp4)
        # astp4 = self.astp4(outs, seqL)
        # Concatenate all granularities along the part axis.
        outs = torch.cat([astp1,astp2, astp2_fta, astp3, astp4], dim=-1) # [n, c, p]
        outs = outs.permute(2, 0, 1).contiguous()  # [p, n, c]
        gait = self.Head0(outs)  # [p, n, c]
        gait = gait.permute(1, 2, 0).contiguous()  # [n, c, p]
        bnft = self.Bn(gait)  # [n, c, p]
        logi = self.Head1(bnft.permute(2, 0, 1).contiguous())  # [p, n, c]

        gait = gait.permute(0, 2, 1).contiguous()  # [n, p, c]
        bnft = bnft.permute(0, 2, 1).contiguous()  # [n, p, c]
        logi = logi.permute(1, 0, 2).contiguous()  # [n, p, c]
        # print(logi.size())

        # Re-read sizes: s may have grown from the short-clip padding above.
        n, _, s, h, w = sils.size()
        retval = {
            'training_feat': {
                'triplet': {'embeddings': bnft, 'labels': labs},
                'softmax': {'logits': logi, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n * s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': bnft
            }
        }
        return retval

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值