# pyskl/models/heads/rgbpose_head.py

import torch
import torch.nn as nn
from mmcv.cnn import normal_init

from pyskl.models.builder import HEADS
from pyskl.models.heads.base import BaseHead


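## @HEADS.register_module()
## NOTE(review): registration is disabled in this standalone copy, presumably so
## it does not clash with the RGBPoseHead that pyskl itself registers — confirm.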
class RGBPoseHead(BaseHead):
    """Two-stream classification head for RGB and pose (heatmap) features.

    Each stream goes through the same steps independently:
    global average pooling, flattening, its own dropout layer and its own
    linear classifier. The result is one set of class scores per modality.

    Args:
        num_classes (int): Number of classes to be classified.
        in_channels (Sequence[int]): Channel counts of the two input
            features, ordered ``(rgb, pose)``.
        loss_cls (dict | None): Config for building the classification loss.
            ``None`` means ``dict(type='CrossEntropyLoss')``. Default: None.
        loss_components (Sequence[str] | None): Names of the loss components,
            stored as ``self.loss_components`` for use outside this class.
            ``None`` means ``['rgb', 'pose']``. Default: None.
        loss_weights (float | Sequence[float]): One weight per entry of
            ``loss_components``. A single number is broadcast to every
            component. Default: 1.
        dropout (float | dict): Dropout probability. A single number is used
            for both streams. A dict must provide ``'rgb'`` and ``'pose'``
            keys. Default: 0.5.
        init_std (float): Std value of the normal init of the linear layers.
            Default: 0.01.
        kwargs (dict, optional): Extra keyword arguments forwarded to
            ``BaseHead.__init__``.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 loss_cls=None,
                 loss_components=None,
                 loss_weights=1.,
                 dropout=0.5,
                 init_std=0.01,
                 **kwargs):
        # Build fresh defaults per call. This avoids mutable default
        # arguments shared by every instance.
        if loss_cls is None:
            loss_cls = dict(type='CrossEntropyLoss')
        if loss_components is None:
            loss_components = ['rgb', 'pose']

        super().__init__(num_classes, in_channels, loss_cls, **kwargs)

        # A single probability (int or float, e.g. 0 or 0.5) applies to both
        # streams.
        if isinstance(dropout, (int, float)):
            dropout = {'rgb': dropout, 'pose': dropout}
        assert isinstance(dropout, dict), \
            f'dropout must be a number or a dict, got {type(dropout)}'

        self.dropout = dropout
        self.init_std = init_std
        self.in_channels = in_channels

        self.loss_components = loss_components
        # A single weight (int or float) is repeated once per loss component.
        if isinstance(loss_weights, (int, float)):
            loss_weights = [loss_weights] * len(loss_components)
        assert len(loss_weights) == len(loss_components), \
            'loss_weights must have one entry per loss component'
        self.loss_weights = loss_weights

        self.dropout_rgb = nn.Dropout(p=self.dropout['rgb'])
        self.dropout_pose = nn.Dropout(p=self.dropout['pose'])

        self.fc_rgb = nn.Linear(in_channels[0], num_classes)
        self.fc_pose = nn.Linear(in_channels[1], num_classes)
        self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))

    def init_weights(self):
        """Initiate the parameters from scratch."""
        normal_init(self.fc_rgb, std=self.init_std)
        normal_init(self.fc_pose, std=self.init_std)

    def forward(self, x):
        """Compute per-modality classification scores.

        Args:
            x (Sequence[torch.Tensor]): ``(rgb_feat, pose_feat)``.
                - Both are pooled by ``nn.AdaptiveAvgPool3d`` and then
                  flattened per sample, so each is expected as a batched
                  5-D ``(N, C, ...)`` tensor.
                - ``C`` must equal ``in_channels[0]`` for the RGB feature and
                  ``in_channels[1]`` for the pose feature.

        Returns:
            dict[str, torch.Tensor]: ``{'rgb': ..., 'pose': ...}``. Each
            entry has shape ``(N, num_classes)``.
        """
        x_rgb, x_pose = self.avg_pool(x[0]), self.avg_pool(x[1])
        # (N, C, 1, 1, 1) -> (N, C)
        x_rgb = x_rgb.view(x_rgb.size(0), -1)
        x_pose = x_pose.view(x_pose.size(0), -1)

        x_rgb = self.dropout_rgb(x_rgb)
        x_pose = self.dropout_pose(x_pose)

        return {
            'rgb': self.fc_rgb(x_rgb),
            'pose': self.fc_pose(x_pose),
        }


if __name__ == '__main__':
    # Demo: run the head on random features. Guarded so that importing this
    # module does not execute it.

    # RGB feature map, shape (N=4, C=2048, 8, 7, 7)
    imgs = torch.randn(4, 2048, 8, 7, 7)
    # Pose heatmap feature map, shape (N=4, C=512, 32, 7, 7)
    heatmap_imgs = torch.randn(4, 512, 32, 7, 7)

    c = (imgs, heatmap_imgs)
    b = RGBPoseHead(num_classes=60,
                    in_channels=[2048, 512])
    print(b.loss_cls)      # CrossEntropyLoss()
    print(b.dropout_rgb)   # Dropout(p=0.5, inplace=False)
    print(b.dropout_pose)  # Dropout(p=0.5, inplace=False)
    print(b.fc_rgb)        # Linear(in_features=2048, out_features=60, bias=True)
    print(b.fc_pose)       # Linear(in_features=512, out_features=60, bias=True)
    print(b.avg_pool)      # AdaptiveAvgPool3d(output_size=(1, 1, 1))
    # Call the module instead of .forward() so nn.Module hooks are honoured.
    output = b(c)
    print(output)



"""  {'rgb': tensor([[-5.7700e-02, -2.3603e-02,  2.5867e-02,  5.2894e-03,  5.3256e-02,
          6.0107e-03, -4.9221e-02,  2.6697e-02, -7.0253e-03,  3.8262e-02,
         -4.0357e-02,  5.1936e-02,  1.2140e-02, -7.4596e-02, -8.9117e-03,
          2.9223e-02,  1.1140e-01,  5.0134e-02,  9.0812e-02, -3.9894e-02,
          1.0598e-02, -4.5024e-02,  1.6093e-02, -5.7336e-03,  5.2367e-02,
          1.4735e-03, -2.4989e-03,  5.1822e-02, -8.4911e-03,  1.7844e-02,
         -1.6949e-02,  7.2522e-02,  5.9389e-02,  1.6918e-02, -4.0477e-02,
          8.5199e-02,  2.5547e-02, -9.7707e-03,  6.7338e-04, -1.5181e-02,
         -1.5108e-02,  4.3326e-03, -8.3722e-02, -6.2588e-04,  3.5934e-02,
         -1.0175e-02,  2.0814e-02, -2.8709e-02,  1.5918e-02, -2.4889e-02,
         -6.0910e-02,  3.2593e-02, -3.5340e-02,  1.3019e-02,  3.1479e-02,
          5.1577e-02, -3.0354e-02,  5.2726e-03, -2.5036e-02,  2.8191e-02],
        [-5.4367e-02, -3.1508e-02,  3.4734e-02,  3.9159e-02,  2.1165e-02,
          3.3528e-02,  4.3701e-02,  1.9485e-02,  9.0981e-02, -5.5509e-02,
          3.3420e-03, -5.1996e-02, -8.2114e-02, -2.0530e-02, -2.1030e-03,
          1.5370e-02,  5.1215e-02,  4.4536e-02,  7.2004e-02,  1.5915e-02,
          4.8107e-02,  1.4031e-02,  1.2598e-01, -6.8748e-02,  4.5748e-02,
          1.8987e-02, -1.2368e-02, -5.7298e-02, -8.7845e-02,  1.0226e-01,
         -1.9828e-02, -3.1821e-02,  5.9785e-02, -2.6085e-03, -4.2438e-02,
          7.0515e-02, -1.4831e-02, -5.8752e-03, -1.2927e-02,  2.5082e-02,
          7.9372e-02, -4.8104e-02,  2.8397e-02,  4.9140e-03, -6.4510e-02,
         -4.9261e-03, -7.6909e-02,  2.2095e-02,  3.3610e-02, -1.3098e-02,
          5.1483e-02, -1.8296e-02, -7.1595e-02,  1.8773e-02,  7.0159e-03,
          2.3975e-02, -7.3588e-02, -5.0598e-02, -7.4159e-02,  4.2474e-02],
        [-2.4720e-02, -1.1364e-01,  9.4980e-03, -7.6735e-02,  1.7628e-02,
          9.4609e-03,  6.6283e-02,  8.2620e-03, -3.5163e-02, -5.7601e-03,
          2.8374e-02,  7.4265e-05, -4.1197e-03,  1.5634e-02, -3.9373e-02,
          7.7164e-03,  3.6293e-02,  1.3336e-02,  3.2416e-02,  1.2017e-02,
          2.9062e-02,  2.9109e-02,  1.6097e-02, -2.5773e-02, -4.1540e-02,
         -1.4493e-02,  6.5766e-02, -3.3982e-02,  6.7399e-02, -5.4210e-02,
         -4.2198e-02,  6.2966e-02,  3.6553e-02, -1.1611e-02, -8.3593e-03,
          5.2436e-02,  1.8363e-02, -6.1198e-02, -1.4441e-02,  4.4367e-02,
         -2.0231e-03,  6.0588e-02, -9.0269e-03, -2.8804e-02,  5.5918e-03,
         -7.9103e-02,  3.8040e-02, -1.8536e-02,  3.3727e-02,  2.5426e-02,
         -1.2080e-02, -4.5306e-02,  6.0841e-03,  2.8949e-02,  2.9411e-02,
          2.8602e-02,  1.0868e-02, -6.8070e-03, -1.9450e-02,  3.4771e-02],
        [ 2.8622e-02,  2.9819e-02, -7.2760e-02,  4.4766e-03,  8.8025e-02,
          1.3980e-03,  9.1454e-02, -1.4714e-02,  3.0700e-02, -9.5994e-03,
          2.7039e-02,  1.0563e-02, -7.2855e-02, -8.7082e-02, -7.9790e-03,
          4.2951e-02, -3.9718e-02, -1.8529e-02,  6.0691e-02, -5.2798e-02,
         -4.1907e-03,  9.2491e-03, -6.6405e-02, -5.5496e-02, -1.6518e-02,
         -6.3773e-02,  5.8387e-02,  5.5772e-02, -2.0838e-02, -6.4116e-02,
         -5.7142e-03,  1.4492e-02,  5.1802e-03, -3.1587e-02,  1.6435e-02,
          7.1817e-02,  2.1124e-02, -3.8407e-02,  8.3376e-03,  1.1023e-02,
         -7.9004e-02, -8.9177e-03,  1.1500e-01,  1.5106e-02,  7.0290e-02,
         -4.7518e-02, -3.9339e-02,  2.4787e-02, -3.3126e-03,  2.3994e-02,
         -4.8938e-02,  9.4477e-05, -3.7757e-02,  2.0614e-02,  7.0945e-02,
          3.7604e-02, -2.4242e-02, -1.2833e-02, -1.2345e-02,  1.6927e-03]],
       grad_fn=<AddmmBackward0>), 'pose': tensor([[ 0.0063,  0.0191, -0.0286,  0.0272,  0.0388,  0.0620, -0.0411,  0.0152,
         -0.0681,  0.0059, -0.0172, -0.0333, -0.0083, -0.0468,  0.0164, -0.0014,
          0.0225,  0.0174,  0.0322, -0.0267, -0.0582,  0.0344,  0.0541,  0.0589,
         -0.0712, -0.0150, -0.0624, -0.0388, -0.0183, -0.0664, -0.0073, -0.0352,
          0.0319,  0.0341,  0.0407, -0.0170, -0.0048, -0.0135, -0.0155,  0.0226,
          0.0133,  0.0194, -0.0307, -0.0462, -0.0603, -0.0263, -0.0114,  0.0155,
          0.0159,  0.0764,  0.0290,  0.0251,  0.0047, -0.0316,  0.0403,  0.0570,
          0.0355, -0.0285,  0.0378,  0.0428],
        [ 0.0220,  0.0172,  0.0154,  0.0324,  0.0579,  0.0344, -0.0548,  0.0147,
          0.0028,  0.0033, -0.0837, -0.0212, -0.0144, -0.0010,  0.0319,  0.0010,
          0.0181,  0.0371,  0.0221,  0.0650,  0.0173, -0.0064,  0.0659,  0.0643,
         -0.0197, -0.0271, -0.0306, -0.0061, -0.0037, -0.0312, -0.0028, -0.0478,
          0.0289,  0.0473, -0.0045, -0.0125,  0.0342, -0.0284, -0.0133,  0.0234,
          0.0007,  0.0247, -0.0211, -0.0051, -0.0134,  0.0107,  0.0052,  0.0295,
          0.0051, -0.0331,  0.0248, -0.0309,  0.0070, -0.0058,  0.0171,  0.0331,
          0.0076, -0.0271,  0.0269,  0.0022],
        [ 0.0306, -0.0027, -0.0085,  0.0052,  0.0303,  0.0189, -0.0295,  0.0149,
         -0.0540,  0.0114, -0.0615, -0.0369, -0.0447, -0.0347,  0.0619, -0.0249,
          0.0481, -0.0096,  0.0308,  0.0316, -0.0282,  0.0057,  0.0393,  0.0352,
         -0.0349, -0.0589, -0.0451, -0.0121, -0.0025, -0.0172, -0.0006, -0.0612,
          0.0422, -0.0092,  0.0043,  0.0053,  0.0396, -0.0142,  0.0476,  0.0109,
          0.0237,  0.0153, -0.0394, -0.0110, -0.0577, -0.0089,  0.0205, -0.0006,
          0.0185,  0.0390,  0.0821,  0.0117,  0.0057, -0.0078,  0.0492,  0.0327,
          0.0254, -0.0156,  0.0358,  0.0331],
        [ 0.0446,  0.0016, -0.0315,  0.0309,  0.0402,  0.0444, -0.0040,  0.0030,
         -0.0685,  0.0303, -0.0385, -0.0067, -0.0012, -0.0362,  0.0237,  0.0183,
          0.0369,  0.0331,  0.0629,  0.0172, -0.0630, -0.0168,  0.0036,  0.0211,
         -0.0514, -0.0300, -0.0381, -0.0291, -0.0081,  0.0039, -0.0086, -0.0288,
          0.0615,  0.0371,  0.0080, -0.0003,  0.0159, -0.0186,  0.0374, -0.0034,
         -0.0202,  0.0068, -0.0046, -0.0213,  0.0109,  0.0209, -0.0021, -0.0395,
          0.0281,  0.0297,  0.0190,  0.0041, -0.0153,  0.0144, -0.0048,  0.0059,
          0.0272, -0.0400,  0.0424,  0.0160]], grad_fn=<AddmmBackward0>)}
    
"""    
    
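## 60 classes, batch size 4: each tensor in the output dict has shape (4, 60),
## i.e. (batch_size, num_classes) — one score per class for each sample.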

# Upstream source: https://github.com/kennymckormick/pyskl/blob/main/pyskl/models/heads/rgbpose_head.py
