import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from pyskl.models.builder import HEADS
from pyskl.models.heads.base import BaseHead
## @HEADS.register_module()
class RGBPoseHead(BaseHead):
    """Two-branch classification head for paired RGB and pose features.

    Each branch is globally average-pooled, passed through its own dropout
    layer and projected to ``num_classes`` scores by its own linear layer.

    Args:
        num_classes (int): Number of classes to be classified.
        in_channels (tuple[int]): Number of channels of the two input
            features; index 0 sizes the RGB branch, index 1 the pose branch.
        loss_cls (dict): Config for building loss.
            Default: dict(type='CrossEntropyLoss').
        loss_components (list[str]): Names of the loss components. Stored on
            the instance; not used by the methods defined in this class.
            Default: ['rgb', 'pose'].
        loss_weights (float | int | list): One weight per loss component. A
            single number is repeated for every component. Stored on the
            instance; not used by the methods defined in this class.
            Default: 1.
        dropout (float | int | dict): Dropout probability. A single number is
            applied to both branches; a dict must provide the keys ``'rgb'``
            and ``'pose'``. Default: 0.5.
        init_std (float): Std value for initialization. Default: 0.01.
        kwargs (dict, optional): Any keyword argument to be used to
            initialize the head (forwarded to ``BaseHead``).
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 loss_cls=dict(type='CrossEntropyLoss'),
                 loss_components=['rgb', 'pose'],
                 loss_weights=1.,
                 dropout=0.5,
                 init_std=0.01,
                 **kwargs):
        super().__init__(num_classes, in_channels, loss_cls, **kwargs)

        # A bare number means "same probability for both branches". Plain
        # ints (e.g. ``dropout=0``) are accepted as well as floats.
        if isinstance(dropout, (int, float)):
            dropout = {'rgb': dropout, 'pose': dropout}
        assert isinstance(dropout, dict), (
            "dropout must be a number or a dict with keys 'rgb' and 'pose'")

        self.dropout = dropout
        self.init_std = init_std
        self.in_channels = in_channels

        # Copy so that the shared default list is never aliased by (and
        # mutated through) an instance attribute.
        self.loss_components = list(loss_components)

        # Multiplying a one-element list by len(loss_components) builds a new
        # list holding that many copies of the scalar weight.
        if isinstance(loss_weights, (int, float)):
            loss_weights = [loss_weights] * len(self.loss_components)
        assert len(loss_weights) == len(self.loss_components), (
            'loss_weights must provide exactly one weight per loss component')
        self.loss_weights = loss_weights

        self.dropout_rgb = nn.Dropout(p=self.dropout['rgb'])
        self.dropout_pose = nn.Dropout(p=self.dropout['pose'])

        self.fc_rgb = nn.Linear(in_channels[0], num_classes)
        self.fc_pose = nn.Linear(in_channels[1], num_classes)
        self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))

    def init_weights(self):
        """Initiate the parameters of both linear layers from scratch."""
        normal_init(self.fc_rgb, std=self.init_std)
        normal_init(self.fc_pose, std=self.init_std)

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (tuple[torch.Tensor]): Pair of feature maps; ``x[0]`` feeds the
                RGB branch and ``x[1]`` the pose branch. The demo in this
                file uses shapes (4, 2048, 8, 7, 7) and (4, 512, 32, 7, 7).

        Returns:
            dict[str, torch.Tensor]: Classification scores under the keys
                ``'rgb'`` and ``'pose'``, each with one row per sample and
                ``num_classes`` columns.
        """
        # Pool every channel down to one value,
        # e.g. (4, 2048, 8, 7, 7) -> (4, 2048, 1, 1, 1).
        x_rgb, x_pose = self.avg_pool(x[0]), self.avg_pool(x[1])

        # Flatten everything after the batch dimension,
        # e.g. (4, 2048, 1, 1, 1) -> (4, 2048) and (4, 512, 1, 1, 1) -> (4, 512).
        x_rgb = x_rgb.view(x_rgb.size(0), -1)
        x_pose = x_pose.view(x_pose.size(0), -1)

        x_rgb = self.dropout_rgb(x_rgb)
        x_pose = self.dropout_pose(x_pose)

        cls_scores = {}
        cls_scores['rgb'] = self.fc_rgb(x_rgb)
        cls_scores['pose'] = self.fc_pose(x_pose)
        return cls_scores
# Dummy RGB feature map: batch of 4, 2048 channels.
imgs = torch.randn(4, 2048, 8, 7, 7)
# Dummy heatmap (pose) feature map: batch of 4, 512 channels.
heatmap_imgs = torch.randn(4, 512, 32, 7, 7)
c = (imgs, heatmap_imgs)

# Build the head for 60 classes; channel counts match the two inputs above.
b = RGBPoseHead(in_channels=[2048, 512], num_classes=60)

# Show the head's sub-modules, one per line. Expected output:
#   CrossEntropyLoss()
#   Dropout(p=0.5, inplace=False)
#   Dropout(p=0.5, inplace=False)
#   Linear(in_features=2048, out_features=60, bias=True)
#   Linear(in_features=512, out_features=60, bias=True)
#   AdaptiveAvgPool3d(output_size=(1, 1, 1))
print(
    b.loss_cls,
    b.dropout_rgb,
    b.dropout_pose,
    b.fc_rgb,
    b.fc_pose,
    b.avg_pool,
    sep='\n',
)

# Call the forward function on the (rgb, pose) pair and show the scores.
output = b.forward(c)
print(output)
""" {'rgb': tensor([[-5.7700e-02, -2.3603e-02, 2.5867e-02, 5.2894e-03, 5.3256e-02,
6.0107e-03, -4.9221e-02, 2.6697e-02, -7.0253e-03, 3.8262e-02,
-4.0357e-02, 5.1936e-02, 1.2140e-02, -7.4596e-02, -8.9117e-03,
2.9223e-02, 1.1140e-01, 5.0134e-02, 9.0812e-02, -3.9894e-02,
1.0598e-02, -4.5024e-02, 1.6093e-02, -5.7336e-03, 5.2367e-02,
1.4735e-03, -2.4989e-03, 5.1822e-02, -8.4911e-03, 1.7844e-02,
-1.6949e-02, 7.2522e-02, 5.9389e-02, 1.6918e-02, -4.0477e-02,
8.5199e-02, 2.5547e-02, -9.7707e-03, 6.7338e-04, -1.5181e-02,
-1.5108e-02, 4.3326e-03, -8.3722e-02, -6.2588e-04, 3.5934e-02,
-1.0175e-02, 2.0814e-02, -2.8709e-02, 1.5918e-02, -2.4889e-02,
-6.0910e-02, 3.2593e-02, -3.5340e-02, 1.3019e-02, 3.1479e-02,
5.1577e-02, -3.0354e-02, 5.2726e-03, -2.5036e-02, 2.8191e-02],
[-5.4367e-02, -3.1508e-02, 3.4734e-02, 3.9159e-02, 2.1165e-02,
3.3528e-02, 4.3701e-02, 1.9485e-02, 9.0981e-02, -5.5509e-02,
3.3420e-03, -5.1996e-02, -8.2114e-02, -2.0530e-02, -2.1030e-03,
1.5370e-02, 5.1215e-02, 4.4536e-02, 7.2004e-02, 1.5915e-02,
4.8107e-02, 1.4031e-02, 1.2598e-01, -6.8748e-02, 4.5748e-02,
1.8987e-02, -1.2368e-02, -5.7298e-02, -8.7845e-02, 1.0226e-01,
-1.9828e-02, -3.1821e-02, 5.9785e-02, -2.6085e-03, -4.2438e-02,
7.0515e-02, -1.4831e-02, -5.8752e-03, -1.2927e-02, 2.5082e-02,
7.9372e-02, -4.8104e-02, 2.8397e-02, 4.9140e-03, -6.4510e-02,
-4.9261e-03, -7.6909e-02, 2.2095e-02, 3.3610e-02, -1.3098e-02,
5.1483e-02, -1.8296e-02, -7.1595e-02, 1.8773e-02, 7.0159e-03,
2.3975e-02, -7.3588e-02, -5.0598e-02, -7.4159e-02, 4.2474e-02],
[-2.4720e-02, -1.1364e-01, 9.4980e-03, -7.6735e-02, 1.7628e-02,
9.4609e-03, 6.6283e-02, 8.2620e-03, -3.5163e-02, -5.7601e-03,
2.8374e-02, 7.4265e-05, -4.1197e-03, 1.5634e-02, -3.9373e-02,
7.7164e-03, 3.6293e-02, 1.3336e-02, 3.2416e-02, 1.2017e-02,
2.9062e-02, 2.9109e-02, 1.6097e-02, -2.5773e-02, -4.1540e-02,
-1.4493e-02, 6.5766e-02, -3.3982e-02, 6.7399e-02, -5.4210e-02,
-4.2198e-02, 6.2966e-02, 3.6553e-02, -1.1611e-02, -8.3593e-03,
5.2436e-02, 1.8363e-02, -6.1198e-02, -1.4441e-02, 4.4367e-02,
-2.0231e-03, 6.0588e-02, -9.0269e-03, -2.8804e-02, 5.5918e-03,
-7.9103e-02, 3.8040e-02, -1.8536e-02, 3.3727e-02, 2.5426e-02,
-1.2080e-02, -4.5306e-02, 6.0841e-03, 2.8949e-02, 2.9411e-02,
2.8602e-02, 1.0868e-02, -6.8070e-03, -1.9450e-02, 3.4771e-02],
[ 2.8622e-02, 2.9819e-02, -7.2760e-02, 4.4766e-03, 8.8025e-02,
1.3980e-03, 9.1454e-02, -1.4714e-02, 3.0700e-02, -9.5994e-03,
2.7039e-02, 1.0563e-02, -7.2855e-02, -8.7082e-02, -7.9790e-03,
4.2951e-02, -3.9718e-02, -1.8529e-02, 6.0691e-02, -5.2798e-02,
-4.1907e-03, 9.2491e-03, -6.6405e-02, -5.5496e-02, -1.6518e-02,
-6.3773e-02, 5.8387e-02, 5.5772e-02, -2.0838e-02, -6.4116e-02,
-5.7142e-03, 1.4492e-02, 5.1802e-03, -3.1587e-02, 1.6435e-02,
7.1817e-02, 2.1124e-02, -3.8407e-02, 8.3376e-03, 1.1023e-02,
-7.9004e-02, -8.9177e-03, 1.1500e-01, 1.5106e-02, 7.0290e-02,
-4.7518e-02, -3.9339e-02, 2.4787e-02, -3.3126e-03, 2.3994e-02,
-4.8938e-02, 9.4477e-05, -3.7757e-02, 2.0614e-02, 7.0945e-02,
3.7604e-02, -2.4242e-02, -1.2833e-02, -1.2345e-02, 1.6927e-03]],
grad_fn=<AddmmBackward0>), 'pose': tensor([[ 0.0063, 0.0191, -0.0286, 0.0272, 0.0388, 0.0620, -0.0411, 0.0152,
-0.0681, 0.0059, -0.0172, -0.0333, -0.0083, -0.0468, 0.0164, -0.0014,
0.0225, 0.0174, 0.0322, -0.0267, -0.0582, 0.0344, 0.0541, 0.0589,
-0.0712, -0.0150, -0.0624, -0.0388, -0.0183, -0.0664, -0.0073, -0.0352,
0.0319, 0.0341, 0.0407, -0.0170, -0.0048, -0.0135, -0.0155, 0.0226,
0.0133, 0.0194, -0.0307, -0.0462, -0.0603, -0.0263, -0.0114, 0.0155,
0.0159, 0.0764, 0.0290, 0.0251, 0.0047, -0.0316, 0.0403, 0.0570,
0.0355, -0.0285, 0.0378, 0.0428],
[ 0.0220, 0.0172, 0.0154, 0.0324, 0.0579, 0.0344, -0.0548, 0.0147,
0.0028, 0.0033, -0.0837, -0.0212, -0.0144, -0.0010, 0.0319, 0.0010,
0.0181, 0.0371, 0.0221, 0.0650, 0.0173, -0.0064, 0.0659, 0.0643,
-0.0197, -0.0271, -0.0306, -0.0061, -0.0037, -0.0312, -0.0028, -0.0478,
0.0289, 0.0473, -0.0045, -0.0125, 0.0342, -0.0284, -0.0133, 0.0234,
0.0007, 0.0247, -0.0211, -0.0051, -0.0134, 0.0107, 0.0052, 0.0295,
0.0051, -0.0331, 0.0248, -0.0309, 0.0070, -0.0058, 0.0171, 0.0331,
0.0076, -0.0271, 0.0269, 0.0022],
[ 0.0306, -0.0027, -0.0085, 0.0052, 0.0303, 0.0189, -0.0295, 0.0149,
-0.0540, 0.0114, -0.0615, -0.0369, -0.0447, -0.0347, 0.0619, -0.0249,
0.0481, -0.0096, 0.0308, 0.0316, -0.0282, 0.0057, 0.0393, 0.0352,
-0.0349, -0.0589, -0.0451, -0.0121, -0.0025, -0.0172, -0.0006, -0.0612,
0.0422, -0.0092, 0.0043, 0.0053, 0.0396, -0.0142, 0.0476, 0.0109,
0.0237, 0.0153, -0.0394, -0.0110, -0.0577, -0.0089, 0.0205, -0.0006,
0.0185, 0.0390, 0.0821, 0.0117, 0.0057, -0.0078, 0.0492, 0.0327,
0.0254, -0.0156, 0.0358, 0.0331],
[ 0.0446, 0.0016, -0.0315, 0.0309, 0.0402, 0.0444, -0.0040, 0.0030,
-0.0685, 0.0303, -0.0385, -0.0067, -0.0012, -0.0362, 0.0237, 0.0183,
0.0369, 0.0331, 0.0629, 0.0172, -0.0630, -0.0168, 0.0036, 0.0211,
-0.0514, -0.0300, -0.0381, -0.0291, -0.0081, 0.0039, -0.0086, -0.0288,
0.0615, 0.0371, 0.0080, -0.0003, 0.0159, -0.0186, 0.0374, -0.0034,
-0.0202, 0.0068, -0.0046, -0.0213, 0.0109, 0.0209, -0.0021, -0.0395,
0.0281, 0.0297, 0.0190, 0.0041, -0.0153, 0.0144, -0.0048, 0.0059,
0.0272, -0.0400, 0.0424, 0.0160]], grad_fn=<AddmmBackward0>)}
"""
## 60 classes, batch size of 4.
## So pay attention to what the 60 and the 4 mean in the output shapes above.
## Reference: https://github.com/kennymckormick/pyskl/blob/main/pyskl/models/heads/rgbpose_head.py