dict(type='FormatShape', input_format='NCTHW')格式

 demo_skeleton

 demo/demo_skeleton.py修改如下:

 
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import cv2
import mmcv
import numpy as np
import os
import os.path as osp
import shutil
import torch
import warnings
from scipy.optimize import linear_sum_assignment

from pyskl.apis import inference_recognizer, init_recognizer

try:
    from mmdet.apis import inference_detector, init_detector
except (ImportError, ModuleNotFoundError):
    def inference_detector(*args, **kwargs):
        pass

    def init_detector(*args, **kwargs):
        pass
    warnings.warn(
        'Failed to import `inference_detector` and `init_detector` from `mmdet.apis`. '
        'Make sure you can successfully import these if you want to use related features. '
    )

try:
    from mmpose.apis import inference_top_down_pose_model, init_pose_model, vis_pose_result
except (ImportError, ModuleNotFoundError):
    def init_pose_model(*args, **kwargs):
        pass

    def inference_top_down_pose_model(*args, **kwargs):
        pass

    def vis_pose_result(*args, **kwargs):
        pass

    warnings.warn(
        'Failed to import `init_pose_model`, `inference_top_down_pose_model`, `vis_pose_result` from '
        '`mmpose.apis`. Make sure you can successfully import these if you want to use related features. '
    )


try:
    import moviepy.editor as mpy
except ImportError:
    raise ImportError('Please install moviepy to enable output file')

# Text style used by main() when drawing the predicted action label on each
# output frame via cv2.putText.
FONTFACE = cv2.FONT_HERSHEY_DUPLEX
FONTSCALE = 0.75
FONTCOLOR = (255, 255, 255)  # BGR, white
THICKNESS = 1
LINETYPE = 1  # passed straight through as cv2.putText's lineType argument


def parse_args():
    """Parse the demo's command-line options.

    Two positional arguments are required: the input video and the output
    file name. Every other option has a default pointing at the recognizer,
    detector (mmdet) and pose estimator (mmpose) configs/checkpoints used by
    the demo. The output file name is echoed to stdout before returning.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='PoseC3D demo')
    parser.add_argument('video', help='video file/url')
    parser.add_argument('out_filename', help='output filename')

    det_checkpoint_url = (
        'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco-person/'
        'faster_rcnn_r50_fpn_1x_coco-person_20201216_175929-d022e227.pth')

    # (flag, add_argument keyword arguments), registered in this order.
    optional_args = (
        ('--config', {
            'default': 'configs/rgbpose_conv3d/rgb_only.py',
            'help': 'skeleton action recognition config file path'}),
        ('--checkpoint', {
            'default': 'https://download.openmmlab.com/mmaction/pyskl/ckpt/rgbpose_conv3d/rgb_only.pth',
            'help': 'skeleton action recognition checkpoint file/url'}),
        ('--det-config', {
            'default': 'demo/faster_rcnn_r50_fpn_1x_coco-person.py',
            'help': 'human detection config file path (from mmdet)'}),
        ('--det-checkpoint', {
            'default': det_checkpoint_url,
            'help': 'human detection checkpoint file/url'}),
        ('--pose-config', {
            'default': 'demo/hrnet_w32_coco_256x192.py',
            'help': 'human pose estimation config file path (from mmpose)'}),
        ('--pose-checkpoint', {
            'default': 'https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_coco_256x192-c78dce93_20200708.pth',
            'help': 'human pose estimation checkpoint file/url'}),
        ('--det-score-thr', {
            'type': float,
            'default': 0.9,
            'help': 'the threshold of human detection score'}),
        ('--label-map', {
            'default': 'tools/data/label_map/nturgbd_120.txt',
            'help': 'label map file'}),
        ('--device', {
            'type': str,
            'default': 'cuda:0',
            'help': 'CPU/CUDA device option'}),
        ('--short-side', {
            'type': int,
            'default': 480,
            'help': 'specify the short-side length of the image'}),
    )
    for flag, spec in optional_args:
        parser.add_argument(flag, **spec)

    parsed = parser.parse_args()
    print('output filename', parsed.out_filename)
    return parsed


def frame_extraction(video_path, short_side):
    """Decode a video, resize every frame and dump the frames as JPEG files.

    Frames are written to ``./tmp/<video name without extension>/`` as
    ``img_000001.jpg``, ``img_000002.jpg``, ... (1-based, six digits).

    Args:
        video_path (str): Video path or URL, as accepted by
            ``cv2.VideoCapture``.
        short_side (int): Target length of the shorter image side. The size
            is computed once from the first frame with ``mmcv.rescale_size``
            and the long side left unbounded.

    Returns:
        tuple[list[str], list[np.ndarray]]: The paths of the written frames
        and the resized frames themselves, in decoding order. Both lists are
        empty if no frame could be read.
    """
    # Load the video, extract frames into ./tmp/video_name
    target_dir = osp.join('./tmp', osp.basename(osp.splitext(video_path)[0]))
    os.makedirs(target_dir, exist_ok=True)
    # Six digits is enough for videos up to several hours.
    frame_tmpl = osp.join(target_dir, 'img_{:06d}.jpg')
    frames = []
    frame_paths = []
    new_w = new_h = None
    vid = cv2.VideoCapture(video_path)
    try:
        flag, frame = vid.read()
        while flag:
            if new_h is None:
                h, w = frame.shape[:2]
                # np.inf rather than np.Inf: the capitalised alias was removed
                # in NumPy 2.0. Infinity leaves the long side unconstrained.
                new_w, new_h = mmcv.rescale_size((w, h), (short_side, np.inf))

            frame = mmcv.imresize(frame, (new_w, new_h))
            frame_path = frame_tmpl.format(len(frame_paths) + 1)
            cv2.imwrite(frame_path, frame)
            frames.append(frame)
            frame_paths.append(frame_path)
            flag, frame = vid.read()
    finally:
        # Release the capture handle even if resizing/writing fails.
        vid.release()

    return frame_paths, frames


def detection_inference(args, frame_paths):
    """Run the mmdet detector on every frame and keep confident person boxes.

    Args:
        args (argparse.Namespace): Supplies ``det_config``,
            ``det_checkpoint``, ``device`` and ``det_score_thr``.
        frame_paths (list[str]): Frame image paths to run detection on.

    Returns:
        list[np.ndarray]: One array per frame with the rows of the detector's
        first-class output (asserted to be ``'person'``) whose score column
        (index 4) is at least ``args.det_score_thr``.
    """
    det_model = init_detector(args.det_config, args.det_checkpoint, args.device)
    assert det_model is not None, ('Failed to build the detection model. Check if you have installed mmcv-full properly. '
                                   'You should first install mmcv-full successfully, then install mmdet, mmpose. ')
    assert det_model.CLASSES[0] == 'person', 'We require you to use a detector trained on COCO'

    print('Performing Human Detection for each frame')
    progress = mmcv.ProgressBar(len(frame_paths))
    person_boxes = []
    for path in frame_paths:
        person_dets = inference_detector(det_model, path)[0]
        confident = person_dets[:, 4] >= args.det_score_thr
        person_boxes.append(person_dets[confident])
        progress.update()
    return person_boxes


def pose_inference(args, frame_paths, det_results):
    """Estimate keypoints for the detected people in every frame.

    Args:
        args (argparse.Namespace): Supplies ``pose_config``,
            ``pose_checkpoint`` and ``device``.
        frame_paths (list[str]): Frame image paths.
        det_results (list[np.ndarray]): Per-frame person boxes (xyxy rows),
            aligned with ``frame_paths``.

    Returns:
        list: For each frame, the first element of what
        ``inference_top_down_pose_model`` returns (the per-person results).
    """
    pose_model = init_pose_model(args.pose_config, args.pose_checkpoint,
                                 args.device)
    print('Performing Human Pose Estimation for each frame')
    progress = mmcv.ProgressBar(len(frame_paths))
    per_frame = []
    for path, boxes in zip(frame_paths, det_results):
        # The top-down API takes one dict per person carrying its box.
        person_inputs = [{'bbox': box} for box in boxes]
        pose_out = inference_top_down_pose_model(pose_model, path, person_inputs, format='xyxy')
        per_frame.append(pose_out[0])
        progress.update()
    return per_frame


def dist_ske(ske1, ske2):
    """Return a dissimilarity score between two skeletons.

    Each row of the inputs is one joint; columns 0-1 are the (x, y) position
    and column 2 the joint score. For every joint the larger of (a) twice the
    Euclidean distance of the positions and (b) the absolute score difference
    is taken, and these per-joint values are summed.
    """
    xy_gap = 2 * np.linalg.norm(ske1[:, :2] - ske2[:, :2], axis=1)
    score_gap = np.abs(ske1[:, 2] - ske2[:, 2])
    return np.maximum(xy_gap, score_gap).sum()


def pose_tracking(pose_results, max_tracks=2, thre=30):
    """Link per-frame skeletons into tracks and stack the longest ones.

    In each frame, tracks whose last pose lies within the previous ``thre``
    frames (last frame index > current index - thre) are matched to the
    frame's poses by ``linear_sum_assignment`` on ``dist_ske`` costs.
    NOTE: a match is accepted regardless of its cost. When a frame has more
    poses than candidate tracks, each unmatched pose starts a new track.

    Args:
        pose_results (list[list[np.ndarray]]): Per-frame lists of
            ``(num_joints, 3)`` arrays (x, y, score).
        max_tracks (int): Number of longest tracks to keep.
        thre (int): How many frames back a track may last have been seen.

    Returns:
        tuple: ``(keypoint, keypoint_score)`` as float16 arrays of shape
        ``(max_tracks, T, num_joints, 2)`` and ``(max_tracks, T, num_joints)``,
        longest track first, zeros where a track has no pose; or
        ``(None, None)`` if no frame contains any pose.
    """
    tracks = []
    track_count = 0
    num_joints = None
    for frame_idx, frame_poses in enumerate(pose_results):
        if not len(frame_poses):
            continue
        if num_joints is None:
            num_joints = frame_poses[0].shape[0]

        active = [trk for trk in tracks if trk['data'][-1][0] > frame_idx - thre]
        cost = np.zeros((len(active), len(frame_poses)))
        for i, trk in enumerate(active):
            last_pose = trk['data'][-1][1]
            for j, pose in enumerate(frame_poses):
                cost[i][j] = dist_ske(last_pose, pose)

        rows, cols = linear_sum_assignment(cost)
        for r, c in zip(rows, cols):
            active[r]['data'].append((frame_idx, frame_poses[c]))

        if len(frame_poses) > len(active):
            matched = set(cols.tolist())
            for j, pose in enumerate(frame_poses):
                if j in matched:
                    continue
                track_count += 1
                tracks.append(dict(data=[(frame_idx, pose)], track_id=track_count))

    if num_joints is None:
        return None, None

    # Stable sort: equally long tracks keep their creation order.
    tracks.sort(key=lambda trk: len(trk['data']), reverse=True)
    stacked = np.zeros((max_tracks, len(pose_results), num_joints, 3), dtype=np.float16)
    for slot, trk in enumerate(tracks[:max_tracks]):
        for frame_idx, pose in trk['data']:
            stacked[slot, frame_idx] = pose
    return stacked[..., :2], stacked[..., 2]


def main():
    """Run the demo end to end and write the annotated video.

    Steps: extract and resize frames, detect people, estimate their poses,
    build an annotation dict for the recognizer, predict an action label,
    then draw the poses and the label on every frame and save the result to
    ``args.out_filename``. The temporary frame directory is removed even if
    a later stage fails.

    Raises:
        RuntimeError: If no frame could be read from ``args.video``.
    """
    args = parse_args()

    frame_paths, original_frames = frame_extraction(args.video,
                                                    args.short_side)
    if not frame_paths:
        raise RuntimeError(f'Could not read any frame from {args.video}')
    tmp_frame_dir = osp.dirname(frame_paths[0])

    try:
        num_frame = len(frame_paths)
        h, w, _ = original_frames[0].shape

        config = mmcv.Config.fromfile(args.config)
        # The annotation built below holds plain in-memory arrays, so the
        # DecompressPose step is dropped from the test pipeline.
        config.data.test.pipeline = [
            x for x in config.data.test.pipeline if x['type'] != 'DecompressPose']
        # Are we using GCN for inference?
        GCN_flag = 'GCN' in config.model.type
        GCN_nperson = None
        if GCN_flag:
            format_op = [op for op in config.data.test.pipeline if op['type'] == 'FormatGCNInput'][0]
            # Fall back to 2, the default num_person of FormatGCNInput.
            GCN_nperson = format_op.get('num_person', 2)

        model = init_recognizer(config, args.checkpoint, args.device)

        # Load label_map (one class name per line); close the file promptly.
        with open(args.label_map) as f:
            label_map = [x.strip() for x in f]

        # Get Human detection results
        det_results = detection_inference(args, frame_paths)
        torch.cuda.empty_cache()

        pose_results = pose_inference(args, frame_paths, det_results)
        torch.cuda.empty_cache()

        fake_anno = dict(
            # Presumably consumed by the RGB branch of the pipeline to locate
            # and decode the source video.
            frame_dir=args.video,
            label=-1,
            img_shape=(h, w),
            original_shape=(h, w),
            start_index=0,
            modality='Pose',
            total_frames=num_frame,
            test_mode=True)

        if GCN_flag:
            # We will keep at most `GCN_nperson` persons per frame.
            tracking_inputs = [[pose['keypoints'] for pose in poses] for poses in pose_results]
            keypoint, keypoint_score = pose_tracking(tracking_inputs, max_tracks=GCN_nperson)
            fake_anno['keypoint'] = keypoint
            fake_anno['keypoint_score'] = keypoint_score
        else:
            num_person = max(len(x) for x in pose_results)
            # Current PoseC3D models are trained on COCO-keypoints (17 keypoints)
            num_keypoint = 17
            keypoint = np.zeros((num_person, num_frame, num_keypoint, 2),
                                dtype=np.float16)
            keypoint_score = np.zeros((num_person, num_frame, num_keypoint),
                                      dtype=np.float16)
            for i, poses in enumerate(pose_results):
                for j, pose in enumerate(poses):
                    kps = pose['keypoints']
                    keypoint[j, i] = kps[:, :2]
                    keypoint_score[j, i] = kps[:, 2]
            fake_anno['keypoint'] = keypoint
            fake_anno['keypoint_score'] = keypoint_score

        if fake_anno['keypoint'] is None:
            # pose_tracking found no pose in any frame: nothing to classify.
            action_label = ''
        else:
            results = inference_recognizer(model, fake_anno)
            # results is a score-sorted list of (class_index, score) pairs.
            top1_index, top1_score = results[0]
            action_label = label_map[top1_index]
            print(f'Predicted action: {action_label} (score: {top1_score})')

        pose_model = init_pose_model(args.pose_config, args.pose_checkpoint,
                                     args.device)
        vis_frames = [
            vis_pose_result(pose_model, frame_path, poses)
            for frame_path, poses in zip(frame_paths, pose_results)
        ]
        for frame in vis_frames:
            cv2.putText(frame, action_label, (10, 30), FONTFACE, FONTSCALE,
                        FONTCOLOR, THICKNESS, LINETYPE)

        # Reverse the channel axis: OpenCV frames are BGR, moviepy wants RGB.
        vid = mpy.ImageSequenceClip([x[:, :, ::-1] for x in vis_frames], fps=24)
        vid.write_videofile(args.out_filename, remove_temp=True)
    finally:
        shutil.rmtree(tmp_frame_dir)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

rgb_only

configs/rgbpose_conv3d/rgb_only.py修改如下:

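先注释掉配置文件第50行到第52行（根据下面打印出的data键值和imgs形状，这三行应为test_pipeline中Normalize之后的FormatShape、Collect和ToTensor）。这样pipeline输出的data中会保留Normalize处理之后、尚未转换格式的imgs，便于下面手动复现FormatShape的操作。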

inference

pyskl/apis/inference.py修改如下:

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import os
import os.path as osp
import re
import torch
import warnings
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from operator import itemgetter

from pyskl.core import OutputHook
from pyskl.datasets.pipelines import Compose
from pyskl.models import build_recognizer
from pyskl.utils import cache_checkpoint
# Small epsilon constant; not referenced by any active code path in this module.
EPS = 1e-3

def init_recognizer(config, checkpoint=None, device='cuda:0', **kwargs):
    """Build a recognizer from a config and optionally load its weights.

    Args:
        config (str | :obj:`mmcv.Config`): Path of a config file, or an
            already loaded config object.
        checkpoint (str | None, optional): Path/URL of the weights to load.
            With None the model keeps its freshly built weights.
            Default: None.
        device (str | :obj:`torch.device`): Device the model is moved to.
            Default: 'cuda:0'.

    Returns:
        nn.Module: The recognizer in eval mode, with ``cfg`` set to the
        config it was built from.

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    if 'use_frames' in kwargs:
        warnings.warn('The argument `use_frames` is deprecated PR #1191. '
                      'Now you can use models trained with frames or videos '
                      'arbitrarily. ')

    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    if not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')

    # Weights (if any) come from `checkpoint`, so skip the backbone's own
    # pretrained initialisation.
    config.model.backbone.pretrained = None
    recognizer = build_recognizer(config.model)
    if checkpoint is not None:
        load_checkpoint(recognizer, cache_checkpoint(checkpoint), map_location='cpu')

    recognizer.cfg = config
    recognizer.to(device)
    recognizer.eval()
    return recognizer


def inference_recognizer(model, video, outputs=None, as_tensor=True, **kwargs):
    """Inference a video with the recognizer.

    Args:
        model (nn.Module): The loaded recognizer; its ``cfg`` attribute (as
            set by ``init_recognizer``) provides the test pipeline.
        video (str | dict | ndarray): The video file path / url or the
            rawframes directory path / results dictionary (the input of
            pipeline) / a 4D array T x H x W x 3 (The input video).
        outputs (list(str) | tuple(str) | str | None) : Names of layers whose
            outputs need to be returned, default: None.
        as_tensor (bool): Same as that in ``OutputHook``. Default: True.

    Returns:
        list[tuple(int, float)]: The top-5 ``(class_index, score)`` pairs,
            sorted by descending score.
        tuple: ``(top5, layer_outputs)`` instead when ``outputs`` is given,
            where ``layer_outputs`` are the outputs captured by
            ``OutputHook`` for the requested layers.

    Raises:
        RuntimeError: If ``video`` is not a supported input type.
    """
    if 'use_frames' in kwargs:
        warnings.warn('The argument `use_frames` is deprecated PR #1191. '
                      'Now you can use models trained with frames or videos '
                      'arbitrarily. ')
    if 'label_path' in kwargs:
        warnings.warn('The argument `label_path` is deprecated PR #1191. '
                      'Now the label file is not needed in '
                      'inference_recognizer. ')

    # Work out which kind of input we were given.
    input_flag = None
    if isinstance(video, dict):
        input_flag = 'dict'
    elif isinstance(video, np.ndarray):
        assert len(video.shape) == 4, 'The shape should be T x H x W x C'
        input_flag = 'array'
    elif isinstance(video, str) and video.startswith('http'):
        input_flag = 'video'
    elif isinstance(video, str) and osp.exists(video):
        if osp.isfile(video):
            input_flag = 'video'
        if osp.isdir(video):
            input_flag = 'rawframes'
    else:
        raise RuntimeError('The type of argument video is not supported: '
                           f'{type(video)}')
    if input_flag is None:
        # An existing path that is neither a regular file nor a directory;
        # without this check `data` below would be left unbound.
        raise RuntimeError('The type of argument video is not supported: '
                           f'{type(video)}')

    if isinstance(outputs, str):
        outputs = (outputs, )
    assert outputs is None or isinstance(outputs, (tuple, list))

    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # Copy the step list so the Init/Decode substitutions below do not
    # rewrite model.cfg for later calls.
    test_pipeline = list(cfg.data.test.pipeline)

    # Alter data pipelines & prepare inputs
    if input_flag == 'dict':
        # Shallow copy: the pipeline adds/replaces keys in the dict it gets,
        # which should not leak back into the caller's annotation.
        data = dict(video)
    if input_flag == 'array':
        modality_map = {2: 'Flow', 3: 'RGB'}
        modality = modality_map.get(video.shape[-1])
        data = dict(
            total_frames=video.shape[0],
            label=-1,
            start_index=0,
            array=video,
            modality=modality)
        for i in range(len(test_pipeline)):
            if 'Decode' in test_pipeline[i]['type']:
                test_pipeline[i] = dict(type='ArrayDecode')
    if input_flag == 'video':
        data = dict(filename=video, label=-1, start_index=0, modality='RGB')
        if 'Init' not in test_pipeline[0]['type']:
            test_pipeline = [dict(type='OpenCVInit')] + test_pipeline
        else:
            test_pipeline[0] = dict(type='OpenCVInit')
        for i in range(len(test_pipeline)):
            if 'Decode' in test_pipeline[i]['type']:
                test_pipeline[i] = dict(type='OpenCVDecode')
    if input_flag == 'rawframes':
        filename_tmpl = cfg.data.test.get('filename_tmpl', 'img_{:05}.jpg')
        modality = cfg.data.test.get('modality', 'RGB')
        start_index = cfg.data.test.get('start_index', 1)

        # count the number of frames that match the format of `filename_tmpl`
        # RGB pattern example: img_{:05}.jpg -> ^img_\d+.jpg$
        # Flow pattern example: {}_{:05d}.jpg -> ^x_\d+.jpg$
        pattern = f'^{filename_tmpl}$'
        if modality == 'Flow':
            pattern = pattern.replace('{}', 'x')
        pattern = pattern.replace(
            pattern[pattern.find('{'):pattern.find('}') + 1], '\\d+')
        total_frames = len(
            list(
                filter(lambda x: re.match(pattern, x) is not None,
                       os.listdir(video))))
        data = dict(
            frame_dir=video,
            total_frames=total_frames,
            label=-1,
            start_index=start_index,
            filename_tmpl=filename_tmpl,
            modality=modality)
        if 'Init' in test_pipeline[0]['type']:
            test_pipeline = test_pipeline[1:]
        for i in range(len(test_pipeline)):
            if 'Decode' in test_pipeline[i]['type']:
                test_pipeline[i] = dict(type='RawFrameDecode')

    # The pipeline itself (e.g. FormatShape / Collect / ToTensor) is
    # responsible for bringing `imgs` into the layout the model expects.
    test_pipeline = Compose(test_pipeline)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)

    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]

    # forward the model
    with OutputHook(model, outputs=outputs, as_tensor=as_tensor) as h:
        with torch.no_grad():
            # assumes one 1-D score vector per sample — TODO confirm for the
            # configured recognizer.
            scores = model(return_loss=False, **data)[0]
        returned_features = h.layer_outputs if outputs else None

    num_classes = scores.shape[-1]
    score_tuples = tuple(zip(range(num_classes), scores))
    score_sorted = sorted(score_tuples, key=itemgetter(1), reverse=True)

    top5_label = score_sorted[:5]
    if outputs:
        return top5_label, returned_features
    return top5_label

# Excerpt: list the keys the (truncated) test pipeline leaves in `data`.
print(data.keys())

begin
dict_keys(['frame_dir', 'label', 'img_shape', 'original_shape', 'start_index', 'modality', 'total_frames', 'test_mode', 'keypoint', 'keypoint_score', 'filename', 'RGB_inds', 'clip_len', 'frame_interval', 'num_clips', 'imgs', 'scale_factor', 'keep_ratio', 'img_norm_cfg'])
end

# Excerpt: print the type and shape of the pipeline's 'imgs' output.
if 'imgs' in data:
        imgs = data['imgs']
        # print(imgs)
        print(type(imgs))
        print(imgs.shape)

<class 'numpy.ndarray'>
(80, 224, 224, 3)

 # Excerpt: additionally print img_shape, num_clips and clip_len from `data`.
 if 'imgs' in data:
        imgs = data['imgs']
        # print(imgs)
        print(type(imgs))
        print(imgs.shape)
    
        img_shape_value = data['img_shape']
        print(img_shape_value)

        num_clips_value = data['num_clips']
        print(num_clips_value) 
        clip_len_value = data['clip_len']
        print(clip_len_value) 

<class 'numpy.ndarray'>
(80, 224, 224, 3)
(224, 224)
10
{'RGB': 8}

# Excerpt: reproduce the NCTHW reshape by hand on the un-formatted 'imgs'.
if 'imgs' in data:
        imgs = data['imgs']
        # print(imgs)
        print(type(imgs))
        print(imgs.shape)
    
        img_shape_value = data['img_shape']
        print(img_shape_value)

        num_clips_value = data['num_clips']
        print(num_clips_value) 
        clip_len_value = data['clip_len']
        print(clip_len_value) 


        # NOTE(review): clip_len is only bound when clip_len_value is a dict;
        # an int clip_len would raise NameError on the reshape below.
        if isinstance(clip_len_value, dict):
            clip_len = clip_len_value['RGB']
            print(clip_len)
        print(imgs.shape)
        # Split the flat frame axis into (crops, clips, frames per clip).
        imgs = imgs.reshape((-1, num_clips_value, clip_len) + imgs.shape[1:])
        print(imgs.shape)
        # N_crops x N_clips x L x H x W x C
        imgs = np.transpose(imgs, (0, 1, 5, 2, 3, 4))
        print(imgs.shape)
        # N_crops x N_clips x C x L x H x W
        # Merge crops and clips into one leading axis: N x C x L x H x W.
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        print(imgs.shape)

<class 'numpy.ndarray'>
(80, 224, 224, 3)
(224, 224)
10
{'RGB': 8}
8
(80, 224, 224, 3)
(1, 10, 8, 224, 224, 3)
(1, 10, 3, 8, 224, 224)
(10, 3, 8, 224, 224)

总结

# N_crops x N_clips x L x H x W x C这个是对应的(1, 10, 8, 224, 224, 3)

# N_crops x N_clips x C x L x H x W这个是对应的(1, 10, 3, 8, 224, 224)

N_crops对应的是:1

N_clips对应的是:10

num_clips_value = data['num_clips']
print(num_clips_value) 

C对应的是:3

L对应的是:8

if isinstance(clip_len_value, dict):
            clip_len = clip_len_value['RGB']
            print(clip_len)  ## 8

H和W对应的是:(224, 224)

最终的格式为'NCTHW'，其中N = N_crops × N_clips = 1 × 10 = 10，即把裁剪数和片段数合并成了一个batch维度。

总结一遍:

dict(type='FormatShape', input_format='NCTHW')

功能为:

输入为:dict(type='Normalize', **img_norm_cfg)处理之后的输出

输出为:print(imgs.shape)是(10, 3, 8, 224, 224)

NCTHW分别对应10, 3, 8, 224, 224

到这里发现确实用到了RGB模态。

https://github.com/kennymckormick/pyskl/blob/main/configs/rgbpose_conv3d/rgb_only.py

  • 9
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值