AIST++ 笔记

目录

开源项目地址

pkl转空间坐标xyz 代码


开源项目地址

aistplusplus_api/demos at main · google/aistplusplus_api · GitHub

有提取特征,可视化的

pkl转空间坐标xyz 代码

"""Demo code for running visualizer."""
import os
import pickle

import numpy as np
from absl import app
from absl import flags
from aist_plusplus.loader import AISTDataset
from aist_plusplus.visualizer import plot_on_video
from smplx import SMPL
import torch

FLAGS = flags.FLAGS
flags.DEFINE_string('anno_dir', '/home/ruilongli/data/AIST++/', 'input local dictionary for AIST++ annotations.')
flags.DEFINE_string('video_dir', '/home/ruilongli/data/AIST/videos/10M/', 'input local dictionary for AIST Dance Videos.')
flags.DEFINE_string('smpl_dir', r'E:\assets\smpl_model/smpl', 'input local dictionary that stores SMPL data.')
flags.DEFINE_string('video_name', 'gBR_sBM_c01_d04_mBR0_ch01', 'input video name to be visualized.')
flags.DEFINE_string('save_dir', './', 'output local dictionary that stores AIST++ visualization.')
flags.DEFINE_enum('mode', 'SMPL', ['2D', '3D', 'SMPL', 'SMPLMesh'], 'visualize 3D or 2D keypoints, or SMPL joints on image plane.')


def main(_):
    # Parsing data info.
    # aist_dataset = AISTDataset(FLAGS.anno_dir)
    # video_path = os.path.join(FLAGS.video_dir, f'{FLAGS.video_name}.mp4')
    # seq_name, view = AISTDataset.get_seq_name(FLAGS.video_name)
    view_idx = 0#AISTDataset.VIEWS.index(view)
    file_path=r"E:\airuhuo_slice0.pkl"
    if FLAGS.mode == 'SMPL':  # SMPL joints

        with open(file_path, "rb") as f:
            poses = pickle.load(f)
            if "smpl_poses" in poses:
                rots = poses["smpl_poses"]  # (N, 72)
                smpl_poses = rots.reshape(-1, 24 * 3)  # (N, 24, 3)
            elif "poses" in poses:
                rots = poses["poses"]
            elif "pos" in poses:
                rots = poses["q"]  # (N, 72)
                smpl_poses = rots.reshape(-1, 24 , 3)  # (N, 24, 3)
                smpl_trans = poses['pos']  # (N, 3)
            else:
                rots = poses["pred_thetas"]  # (N, 72)
                smpl_poses = rots.reshape(-1, 22*3)  # (N, 24, 3)

        smpl_scaling = poses.get('smpl_scaling',[1])
        smpl_scaling=np.asarray(smpl_scaling)
        # smpl_trans = poses['smpl_trans']  # (N, 3)

        # smpl_poses, smpl_scaling, smpl_trans = AISTDataset.load_motion(motion_dir, seq_name)
        smpl = SMPL(model_path=FLAGS.smpl_dir, gender='MALE', batch_size=1)
        keypoints3d = smpl.forward(global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(), body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(), transl=torch.from_numpy(smpl_trans).float(), scaling=torch.from_numpy(smpl_scaling.reshape(1, 1)).float(), ).joints.detach().numpy()

        vals = keypoints3d*10
        out_path="0328.npz"
        print('save frames', vals.shape[0], file_path)
        np.savez_compressed(out_path, joints_3d={"data": vals})

        nframes, njoints, _ = keypoints3d.shape
        # env_name = aist_dataset.mapping_seq2env[seq_name]
        # cgroup = AISTDataset.load_camera_group(aist_dataset.camera_dir, env_name)
        # keypoints2d = cgroup.project(keypoints3d)
        # keypoints2d = keypoints2d.reshape(9, nframes, njoints, 2)[view_idx]

    elif FLAGS.mode == 'SMPLMesh':  # SMPL Mesh
        import trimesh  # install by `pip install trimesh`
        import vedo  # install by `pip install vedo`
        smpl_poses, smpl_scaling, smpl_trans = AISTDataset.load_motion(aist_dataset.motion_dir, seq_name)
        smpl = SMPL(model_path=FLAGS.smpl_dir, gender='MALE', batch_size=1)
        vertices = smpl.forward(global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(), body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(), transl=torch.from_numpy(smpl_trans).float(), scaling=torch.from_numpy(smpl_scaling.reshape(1, 1)).float(), ).vertices.detach().numpy()[0]  # first frame
        faces = smpl.faces
        mesh = trimesh.Trimesh(vertices, faces)
        mesh.visual.face_colors = [200, 200, 250, 100]

        keypoints3d = AISTDataset.load_keypoint3d(aist_dataset.keypoint3d_dir, seq_name, use_optim=True)
        pts = vedo.Points(keypoints3d[0], r=20)  # first frame

        vedo.show(mesh, pts, interactive=True)
        exit()

if __name__ == '__main__':
    app.run(main)

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI算法加油站

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值