demo_skeleton
demo/demo_skeleton.py修改如下:
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import cv2
import mmcv
import numpy as np
import os
import os.path as osp
import shutil
import torch
import warnings
from scipy.optimize import linear_sum_assignment
from pyskl.apis import inference_recognizer, init_recognizer
# mmdet is an optional dependency. When it cannot be imported, the two
# detector helpers are replaced by no-op stubs (returning None) and a warning
# is emitted so the failure surfaces when the detector is actually used.
try:
    from mmdet.apis import inference_detector, init_detector
except ImportError:  # ModuleNotFoundError is a subclass of ImportError

    def inference_detector(*args, **kwargs):
        """No-op stand-in used when mmdet is unavailable; returns None."""
        pass

    def init_detector(*args, **kwargs):
        """No-op stand-in used when mmdet is unavailable; returns None."""
        pass

    warnings.warn(
        'Failed to import `inference_detector` and `init_detector` from `mmdet.apis`. '
        'Make sure you can successfully import these if you want to use related features. '
    )
# mmpose is an optional dependency. When it cannot be imported, the three pose
# helpers are replaced by no-op stubs (returning None) and a warning is
# emitted so the failure surfaces when pose estimation is actually used.
try:
    from mmpose.apis import inference_top_down_pose_model, init_pose_model, vis_pose_result
except ImportError:  # ModuleNotFoundError is a subclass of ImportError

    def init_pose_model(*args, **kwargs):
        """No-op stand-in used when mmpose is unavailable; returns None."""
        pass

    def inference_top_down_pose_model(*args, **kwargs):
        """No-op stand-in used when mmpose is unavailable; returns None."""
        pass

    def vis_pose_result(*args, **kwargs):
        """No-op stand-in used when mmpose is unavailable; returns None."""
        pass

    warnings.warn(
        'Failed to import `init_pose_model`, `inference_top_down_pose_model`, `vis_pose_result` from '
        '`mmpose.apis`. Make sure you can successfully import these if you want to use related features. '
    )
# moviepy is mandatory for this demo: it is used to encode the output video.
try:
    import moviepy.editor as mpy
except ImportError as err:
    # Chain the original error so the real cause of the failure stays visible.
    raise ImportError('Please install moviepy to enable output file') from err
# Text-overlay settings handed to cv2.putText when the predicted action label
# is drawn onto each visualised frame.
FONTFACE = cv2.FONT_HERSHEY_DUPLEX
FONTSCALE = 0.75
FONTCOLOR = (255, 255, 255) # BGR, white
THICKNESS = 1
LINETYPE = 1
def parse_args():
    """Parse the command-line arguments of the PoseC3D demo.

    Two positional arguments are required (input video and output filename);
    every model config/checkpoint and threshold has a default.

    Returns:
        argparse.Namespace: The parsed arguments. The output filename is
            also echoed to stdout.
    """
    parser = argparse.ArgumentParser(description='PoseC3D demo')
    parser.add_argument('video', help='video file/url')
    parser.add_argument('out_filename', help='output filename')
    parser.add_argument(
        '--config',
        default='configs/rgbpose_conv3d/rgb_only.py',
        help='skeleton action recognition config file path')
    parser.add_argument(
        '--checkpoint',
        default='https://download.openmmlab.com/mmaction/pyskl/ckpt/rgbpose_conv3d/rgb_only.pth',
        help='skeleton action recognition checkpoint file/url')
    parser.add_argument(
        '--det-config',
        default='demo/faster_rcnn_r50_fpn_1x_coco-person.py',
        help='human detection config file path (from mmdet)')
    parser.add_argument(
        '--det-checkpoint',
        default=('https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco-person/'
                 'faster_rcnn_r50_fpn_1x_coco-person_20201216_175929-d022e227.pth'),
        help='human detection checkpoint file/url')
    parser.add_argument(
        '--pose-config',
        default='demo/hrnet_w32_coco_256x192.py',
        help='human pose estimation config file path (from mmpose)')
    parser.add_argument(
        '--pose-checkpoint',
        default='https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_coco_256x192-c78dce93_20200708.pth',
        help='human pose estimation checkpoint file/url')
    parser.add_argument(
        '--det-score-thr',
        type=float,
        default=0.9,
        help='the threshold of human detection score')
    parser.add_argument(
        '--label-map',
        default='tools/data/label_map/nturgbd_120.txt',
        help='label map file')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument(
        '--short-side',
        type=int,
        default=480,
        help='specify the short-side length of the image')
    args = parser.parse_args()
    print('output filename', args.out_filename)
    return args
def frame_extraction(video_path, short_side):
    """Extract frames given video_path.

    Each decoded frame is resized to the size computed from the first frame
    and `short_side`, kept in memory, and written to
    ``./tmp/<video name>/img_{:06d}.jpg`` (numbering starts at 1).

    Args:
        video_path (str): The video_path.
        short_side (int): Scale bound handed to ``mmcv.rescale_size``
            together with ``np.inf`` to compute the target frame size.

    Returns:
        tuple[list[str], list[np.ndarray]]: The paths of the written frames
            and the resized frames themselves, in decoding order. Both
            lists are empty when no frame could be read.
    """
    # Load the video, extract frames into ./tmp/video_name
    target_dir = osp.join('./tmp', osp.basename(osp.splitext(video_path)[0]))
    os.makedirs(target_dir, exist_ok=True)
    # Should be able to handle videos up to several hours
    frame_tmpl = osp.join(target_dir, 'img_{:06d}.jpg')
    vid = cv2.VideoCapture(video_path)
    frames = []
    frame_paths = []
    new_size = None
    try:
        flag, frame = vid.read()
        while flag:
            if new_size is None:
                # The target size is derived once, from the first frame.
                h, w = frame.shape[:2]
                # `np.inf` instead of `np.Inf`: the latter alias was removed
                # in NumPy 2.0.
                new_size = mmcv.rescale_size((w, h), (short_side, np.inf))
            frame = mmcv.imresize(frame, new_size)

            frames.append(frame)
            frame_path = frame_tmpl.format(len(frame_paths) + 1)
            frame_paths.append(frame_path)

            cv2.imwrite(frame_path, frame)
            flag, frame = vid.read()
    finally:
        # Always free the capture handle, even if resizing/writing fails.
        vid.release()

    return frame_paths, frames
def detection_inference(args, frame_paths):
    """Detect human boxes given frame paths.

    Args:
        args (argparse.Namespace): The arguments. `det_config`,
            `det_checkpoint`, `device` and `det_score_thr` are read.
        frame_paths (list[str]): The paths of frames to do detection inference.

    Returns:
        list[np.ndarray]: The human detection results, one array per frame,
            containing only the rows whose column 4 (the score) is at
            least `args.det_score_thr`.
    """
    model = init_detector(args.det_config, args.det_checkpoint, args.device)
    # The import-fallback stub of `init_detector` returns None, so this also
    # catches a missing mmdet installation.
    assert model is not None, ('Failed to build the detection model. Check if you have installed mmcv-full properly. '
                               'You should first install mmcv-full successfully, then install mmdet, mmpose. ')
    assert model.CLASSES[0] == 'person', 'We require you to use a detector trained on COCO'
    results = []
    print('Performing Human Detection for each frame')
    prog_bar = mmcv.ProgressBar(len(frame_paths))
    for frame_path in frame_paths:
        # Entry 0 of the detector output is the class checked above ('person').
        person_dets = inference_detector(model, frame_path)[0]
        # We only keep human detections with score larger than det_score_thr
        results.append(person_dets[person_dets[:, 4] >= args.det_score_thr])
        prog_bar.update()
    return results
def pose_inference(args, frame_paths, det_results):
    """Estimate human poses for every frame from its detection boxes.

    Args:
        args (argparse.Namespace): The arguments. `pose_config`,
            `pose_checkpoint` and `device` are read.
        frame_paths (list[str]): The paths of the frames.
        det_results (list[np.ndarray]): Per-frame detection boxes, aligned
            with `frame_paths`.

    Returns:
        list: For every frame, the first element returned by
            ``inference_top_down_pose_model`` for that frame.

    Raises:
        RuntimeError: If the pose model could not be built.
    """
    model = init_pose_model(args.pose_config, args.pose_checkpoint,
                            args.device)
    # The import-fallback stub of `init_pose_model` returns None; fail here
    # with a clear message instead of a TypeError further down.
    if model is None:
        raise RuntimeError('Failed to build the pose estimation model. '
                           'Check that mmpose is installed properly.')
    ret = []
    print('Performing Human Pose Estimation for each frame')
    prog_bar = mmcv.ProgressBar(len(frame_paths))
    for frame_path, boxes in zip(frame_paths, det_results):
        # Align input format: one dict with a `bbox` entry per detection
        person_boxes = [dict(bbox=box) for box in boxes]
        pose = inference_top_down_pose_model(model, frame_path, person_boxes, format='xyxy')[0]
        ret.append(pose)
        prog_bar.update()
    return ret
def dist_ske(ske1, ske2):
    """Return the matching cost between two skeletons.

    Args:
        ske1 (np.ndarray): Skeleton indexed as ``[:, :2]`` (coordinates)
            and ``[:, 2]`` (third column, used as score by the callers).
        ske2 (np.ndarray): Skeleton with the same layout as `ske1`.

    Returns:
        float: Sum over joints of ``max(2 * coordinate distance,
            absolute difference of the third column)``.
    """
    dist = np.linalg.norm(ske1[:, :2] - ske2[:, :2], axis=1) * 2
    diff = np.abs(ske1[:, 2] - ske2[:, 2])
    return np.sum(np.maximum(dist, diff))


def pose_tracking(pose_results, max_tracks=2, thre=30):
    """Link per-frame poses into tracks and stack the longest tracks.

    Poses of a frame are matched to the recently updated tracks by solving a
    linear assignment problem over `dist_ske` costs; unmatched poses start
    new tracks.

    Args:
        pose_results (list[list[np.ndarray]]): For every frame, the poses
            found in it. Each pose is indexed as ``[:, :2]`` and ``[:, 2]``.
        max_tracks (int): Number of tracks (longest first) kept in the
            output. Default: 2.
        thre (int): A track is a matching candidate only if its latest pose
            is fewer than `thre` frames old. Default: 30.

    Returns:
        tuple: ``(None, None)`` when no frame contains a pose. Otherwise a
            float16 array of shape (max_tracks, num_frames, num_joints, 2)
            and a float16 array of shape (max_tracks, num_frames,
            num_joints); entries without a pose stay zero.
    """
    tracks, num_tracks = [], 0
    num_joints = None
    for idx, poses in enumerate(pose_results):
        if len(poses) == 0:
            continue
        if num_joints is None:
            num_joints = poses[0].shape[0]
        track_proposals = [t for t in tracks if t['data'][-1][0] > idx - thre]

        matched = set()
        if track_proposals:
            # All costs are computed before any track is extended, so every
            # cost refers to the tracks' state prior to this frame.
            scores = np.array([[dist_ske(t['data'][-1][1], pose) for pose in poses]
                               for t in track_proposals])
            row, col = linear_sum_assignment(scores)
            for r, c in zip(row, col):
                track_proposals[r]['data'].append((idx, poses[c]))
                matched.add(int(c))

        # Poses left without a track (only possible when there are more
        # poses than candidate tracks) each open a new track.
        for j, pose in enumerate(poses):
            if j not in matched:
                num_tracks += 1
                tracks.append(dict(track_id=num_tracks, data=[(idx, pose)]))

    if num_joints is None:
        return None, None

    # Longest tracks first; the sort is stable, so ties keep creation order.
    tracks.sort(key=lambda x: -len(x['data']))
    result = np.zeros((max_tracks, len(pose_results), num_joints, 3), dtype=np.float16)
    for i, track in enumerate(tracks[:max_tracks]):
        for idx, pose in track['data']:
            result[i, idx] = pose
    return result[..., :2], result[..., 2]
def _recognize_and_render(args, frame_paths, original_frames):
    """Run detection, pose estimation and recognition, then write the video.

    Args:
        args (argparse.Namespace): The parsed demo arguments.
        frame_paths (list[str]): Paths of the extracted frames (non-empty).
        original_frames (list[np.ndarray]): The extracted frames.
    """
    num_frame = len(frame_paths)
    h, w, _ = original_frames[0].shape

    config = mmcv.Config.fromfile(args.config)
    config.data.test.pipeline = [x for x in config.data.test.pipeline if x['type'] != 'DecompressPose']
    # Are we using GCN for inference?
    GCN_flag = 'GCN' in config.model.type
    GCN_nperson = None
    if GCN_flag:
        format_op = [op for op in config.data.test.pipeline if op['type'] == 'FormatGCNInput'][0]
        # We will set the default value of GCN_nperson to 2, which is
        # the default arg of FormatGCNInput
        GCN_nperson = format_op.get('num_person', 2)

    model = init_recognizer(config, args.checkpoint, args.device)

    # Load label_map; the context manager closes the file handle.
    with open(args.label_map) as label_file:
        label_map = [line.strip() for line in label_file]

    # Get Human detection results
    det_results = detection_inference(args, frame_paths)
    torch.cuda.empty_cache()

    pose_results = pose_inference(args, frame_paths, det_results)
    torch.cuda.empty_cache()

    fake_anno = dict(
        # NOTE(review): the video path is stored under `frame_dir`;
        # presumably the RGB decoding step of the pipeline reads it from
        # there - confirm against the pipeline in use.
        frame_dir=args.video,
        label=-1,
        img_shape=(h, w),
        original_shape=(h, w),
        start_index=0,
        modality='Pose',
        total_frames=num_frame,
        test_mode=True)

    if GCN_flag:
        # We will keep at most `GCN_nperson` persons per frame.
        tracking_inputs = [[pose['keypoints'] for pose in poses] for poses in pose_results]
        keypoint, keypoint_score = pose_tracking(tracking_inputs, max_tracks=GCN_nperson)
    else:
        num_person = max(len(x) for x in pose_results)
        # Current PoseC3D models are trained on COCO-keypoints (17 keypoints)
        num_keypoint = 17
        keypoint = np.zeros((num_person, num_frame, num_keypoint, 2),
                            dtype=np.float16)
        keypoint_score = np.zeros((num_person, num_frame, num_keypoint),
                                  dtype=np.float16)
        for i, poses in enumerate(pose_results):
            for j, pose in enumerate(poses):
                pose = pose['keypoints']
                keypoint[j, i] = pose[:, :2]
                keypoint_score[j, i] = pose[:, 2]
    fake_anno['keypoint'] = keypoint
    fake_anno['keypoint_score'] = keypoint_score

    if fake_anno['keypoint'] is None:
        # `pose_tracking` found no pose at all: nothing to recognise.
        action_label = ''
    else:
        results = inference_recognizer(model, fake_anno)
        # results[0] is the best (class index, score) pair.
        action_label = label_map[results[0][0]]
        action_label_score = results[0][1]
        print(action_label)
        print(action_label_score)

    # `pose_inference` does not hand back its model, so one is built again
    # for `vis_pose_result`.
    pose_model = init_pose_model(args.pose_config, args.pose_checkpoint,
                                 args.device)
    vis_frames = [
        vis_pose_result(pose_model, frame_path, poses)
        for frame_path, poses in zip(frame_paths, pose_results)
    ]
    for frame in vis_frames:
        cv2.putText(frame, action_label, (10, 30), FONTFACE, FONTSCALE,
                    FONTCOLOR, THICKNESS, LINETYPE)

    # Channel order is reversed (BGR -> RGB) before handing frames to moviepy.
    vid = mpy.ImageSequenceClip([x[:, :, ::-1] for x in vis_frames], fps=24)
    vid.write_videofile(args.out_filename, remove_temp=True)


def main():
    """Entry point of the demo: recognise the action in a video and render it.

    Raises:
        RuntimeError: If no frame could be read from the input video.
    """
    args = parse_args()

    frame_paths, original_frames = frame_extraction(args.video,
                                                    args.short_side)
    if not frame_paths:
        raise RuntimeError(f'No frames could be read from {args.video}')

    tmp_frame_dir = osp.dirname(frame_paths[0])
    try:
        _recognize_and_render(args, frame_paths, original_frames)
    finally:
        # Remove the extracted frames whether or not the demo succeeded.
        shutil.rmtree(tmp_frame_dir)


if __name__ == '__main__':
    main()
rgb_only
configs/rgbpose_conv3d/rgb_only.py修改如下:
先注释掉配置文件第50行到第52行
inference
pyskl/apis/inference.py修改如下:
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import os
import os.path as osp
import re
import torch
import warnings
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from operator import itemgetter
from pyskl.core import OutputHook
from pyskl.datasets.pipelines import Compose
from pyskl.models import build_recognizer
from pyskl.utils import cache_checkpoint
EPS = 1e-3
def init_recognizer(config, checkpoint=None, device='cuda:0', **kwargs):
    """Initialize a recognizer from config file.

    Note that the passed config object is modified in place
    (`model.backbone.pretrained` is set to None) and attached to the
    returned model as `model.cfg`.

    Args:
        config (str | :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str | None, optional): Checkpoint path/url. If set to None,
            the model will not load any weights. Default: None.
        device (str | :obj:`torch.device`): The desired device of returned
            tensor. Default: 'cuda:0'.
        **kwargs: Only inspected for the deprecated `use_frames` argument,
            which triggers a warning.

    Returns:
        nn.Module: The constructed recognizer, moved to `device` and set to
            eval mode.

    Raises:
        TypeError: If `config` is neither a filename nor a Config object.
    """
    if 'use_frames' in kwargs:
        warnings.warn('The argument `use_frames` is deprecated PR #1191. '
                      'Now you can use models trained with frames or videos '
                      'arbitrarily. ')

    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')

    # pretrained model is unnecessary since we directly load checkpoint later
    config.model.backbone.pretrained = None
    model = build_recognizer(config.model)

    if checkpoint is not None:
        checkpoint = cache_checkpoint(checkpoint)
        load_checkpoint(model, checkpoint, map_location='cpu')
    model.cfg = config
    model.to(device)
    model.eval()
    return model
def inference_recognizer(model, video, outputs=None, as_tensor=True, **kwargs):
    """Inference a video with the recognizer.

    Args:
        model (nn.Module): The loaded recognizer.
        video (str | dict | ndarray): The video file path / url or the
            rawframes directory path / results dictionary (the input of
            pipeline) / a 4D array T x H x W x 3 (The input video).
        outputs (list(str) | tuple(str) | str | None) : Names of layers whose
            outputs need to be returned, default: None.
        as_tensor (bool): Same as that in ``OutputHook``. Default: True.
        **kwargs: Only inspected for the deprecated `use_frames` and
            `label_path` arguments, which trigger a warning.

    Returns:
        list[tuple(int, float)]: The (up to) five best (class index, score)
            pairs, highest score first.
        dict[torch.tensor | np.ndarray]:
            Output feature maps from layers specified in `outputs`. Only
            returned (as second element of a tuple) when `outputs` is set.

    Raises:
        RuntimeError: If the type of `video` is not supported.
        ValueError: If the model returns an empty score dict or a
            zero-dimensional score.
    """
    if 'use_frames' in kwargs:
        warnings.warn('The argument `use_frames` is deprecated PR #1191. '
                      'Now you can use models trained with frames or videos '
                      'arbitrarily. ')
    if 'label_path' in kwargs:
        warnings.warn('The argument `label_path` is deprecated PR #1191. '
                      'Now the label file is not needed in '
                      'inference_recognizer. ')

    # Classify the input before touching it: only the dict form has keys.
    input_flag = None
    if isinstance(video, dict):
        input_flag = 'dict'
    elif isinstance(video, np.ndarray):
        assert len(video.shape) == 4, 'The shape should be T x H x W x C'
        input_flag = 'array'
    elif isinstance(video, str) and video.startswith('http'):
        input_flag = 'video'
    elif isinstance(video, str) and osp.exists(video):
        if osp.isfile(video):
            input_flag = 'video'
        if osp.isdir(video):
            input_flag = 'rawframes'
    else:
        raise RuntimeError('The type of argument video is not supported: '
                           f'{type(video)}')

    if isinstance(outputs, str):
        outputs = (outputs, )
    assert outputs is None or isinstance(outputs, (tuple, list))

    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline; work on a shallow copy so the replacements
    # below do not rewrite the config stored on the model
    test_pipeline = list(cfg.data.test.pipeline)
    # Alter data pipelines & prepare inputs
    if input_flag == 'dict':
        data = video
    if input_flag == 'array':
        modality_map = {2: 'Flow', 3: 'RGB'}
        modality = modality_map.get(video.shape[-1])
        data = dict(
            total_frames=video.shape[0],
            label=-1,
            start_index=0,
            array=video,
            modality=modality)
        for i in range(len(test_pipeline)):
            if 'Decode' in test_pipeline[i]['type']:
                test_pipeline[i] = dict(type='ArrayDecode')
    if input_flag == 'video':
        data = dict(filename=video, label=-1, start_index=0, modality='RGB')
        if 'Init' not in test_pipeline[0]['type']:
            test_pipeline = [dict(type='OpenCVInit')] + test_pipeline
        else:
            test_pipeline[0] = dict(type='OpenCVInit')
        for i in range(len(test_pipeline)):
            if 'Decode' in test_pipeline[i]['type']:
                test_pipeline[i] = dict(type='OpenCVDecode')
    if input_flag == 'rawframes':
        filename_tmpl = cfg.data.test.get('filename_tmpl', 'img_{:05}.jpg')
        modality = cfg.data.test.get('modality', 'RGB')
        start_index = cfg.data.test.get('start_index', 1)

        # count the number of frames that match the format of `filename_tmpl`
        # RGB pattern example: img_{:05}.jpg -> ^img_\d+.jpg$
        # Flow pattern example: {}_{:05d}.jpg -> ^x_\d+.jpg$
        pattern = f'^{filename_tmpl}$'
        if modality == 'Flow':
            pattern = pattern.replace('{}', 'x')
        pattern = pattern.replace(
            pattern[pattern.find('{'):pattern.find('}') + 1], '\\d+')
        total_frames = len(
            list(
                filter(lambda x: re.match(pattern, x) is not None,
                       os.listdir(video))))
        data = dict(
            frame_dir=video,
            total_frames=total_frames,
            label=-1,
            start_index=start_index,
            filename_tmpl=filename_tmpl,
            modality=modality)
        if 'Init' in test_pipeline[0]['type']:
            test_pipeline = test_pipeline[1:]
        for i in range(len(test_pipeline)):
            if 'Decode' in test_pipeline[i]['type']:
                test_pipeline[i] = dict(type='RawFrameDecode')

    test_pipeline = Compose(test_pipeline)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)

    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]

    # forward the model
    with OutputHook(model, outputs=outputs, as_tensor=as_tensor) as h:
        with torch.no_grad():
            output = model(return_loss=False, **data)[0]
        returned_features = h.layer_outputs if outputs else None

    if isinstance(output, dict):
        # NOTE(review): a dict output is assumed to hold one score vector per
        # modality (e.g. 'rgb' and 'pose'); they are fused by element-wise
        # maximum - confirm this is the intended fusion rule.
        if not output:
            raise ValueError('The model returned an empty score dict')
        scores = np.maximum.reduce([np.asarray(v) for v in output.values()])
    else:
        scores = output

    if scores.ndim == 0:
        raise ValueError(f'Unexpected shape: {scores.shape}')
    num_classes = scores.shape[-1]
    score_tuples = tuple(zip(range(num_classes), scores))
    score_sorted = sorted(score_tuples, key=itemgetter(1), reverse=True)

    top5_label = score_sorted[:5]
    if outputs:
        return top5_label, returned_features
    return top5_label
begin
dict_keys(['frame_dir', 'label', 'img_shape', 'original_shape', 'start_index', 'modality', 'total_frames', 'test_mode', 'keypoint', 'keypoint_score', 'filename', 'RGB_inds', 'clip_len', 'frame_interval', 'num_clips', 'imgs', 'scale_factor', 'keep_ratio', 'img_norm_cfg'])
end
if 'imgs' in data:
imgs = data['imgs']
# print(imgs)
print(type(imgs))
print(imgs.shape)
<class 'numpy.ndarray'>
(80, 224, 224, 3)
if 'imgs' in data:
imgs = data['imgs']
# print(imgs)
print(type(imgs))
print(imgs.shape)
img_shape_value = data['img_shape']
print(img_shape_value)
num_clips_value = data['num_clips']
print(num_clips_value)
clip_len_value = data['clip_len']
print(clip_len_value)
<class 'numpy.ndarray'>
(80, 224, 224, 3)
(224, 224)
10
{'RGB': 8}
if 'imgs' in data:
imgs = data['imgs']
# print(imgs)
print(type(imgs))
print(imgs.shape)
img_shape_value = data['img_shape']
print(img_shape_value)
num_clips_value = data['num_clips']
print(num_clips_value)
clip_len_value = data['clip_len']
print(clip_len_value)
if isinstance(clip_len_value, dict):
clip_len = clip_len_value['RGB']
print(clip_len)
print(imgs.shape)
imgs = imgs.reshape((-1, num_clips_value, clip_len) + imgs.shape[1:])
print(imgs.shape)
# N_crops x N_clips x L x H x W x C
imgs = np.transpose(imgs, (0, 1, 5, 2, 3, 4))
print(imgs.shape)
# N_crops x N_clips x C x L x H x W
imgs = imgs.reshape((-1, ) + imgs.shape[2:])
print(imgs.shape)
<class 'numpy.ndarray'>
(80, 224, 224, 3)
(224, 224)
10
{'RGB': 8}
8
(80, 224, 224, 3)
(1, 10, 8, 224, 224, 3)
(1, 10, 3, 8, 224, 224)
(10, 3, 8, 224, 224)
总结
# N_crops x N_clips x L x H x W x C这个是对应的(1, 10, 8, 224, 224, 3)
# N_crops x N_clips x C x L x H x W这个是对应的(1, 10, 3, 8, 224, 224)
N_crops对应的是:1
N_clips对应的是:10
num_clips_value = data['num_clips']
print(num_clips_value)
C对应的是:3
L对应的是:8
if isinstance(clip_len_value, dict):
clip_len = clip_len_value['RGB']
print(clip_len) ## 8
H和W对应的是:(224, 224)
最终的格式为'NCTHW'
总结一遍:
dict(type='FormatShape', input_format='NCTHW')
功能为:
输入为:dict(type='Normalize', **img_norm_cfg)处理之后的输出
输出为:print(imgs.shape)是(10, 3, 8, 224, 224)
NCTHW分别对应10, 3, 8, 224, 224
到这里发现确实用到了RGB模态。
https://github.com/kennymckormick/pyskl/blob/main/configs/rgbpose_conv3d/rgb_only.py