Northwestern-UCLA数据处理成AGCN网络可用的输入数据

import argparse
import pickle
from tqdm import tqdm
import sys

sys.path.extend(['../'])
from data_gen.preprocess import pre_normalization


training_cameras = ['1', '2']
max_body_true = 1
max_body = 2
num_joint = 20
max_frame = 202

import numpy as np
import os


def get_nonzero_std(s):  # tvc
    index = s.sum(-1).sum(-1) != 0  # select valid frames
    s = s[index]
    if len(s) != 0:
        s = s[:, :, 0].std() + s[:, :, 1].std() + s[:, :, 2].std()  # three channels
    else:
        s = 0
    return s


def read_xyz(data_path, max_body=2, num_joint=25):  # 取了前两个body
    data = np.zeros((max_body, max_frame, num_joint, 3))
    frame_num = 0
    per_dict = {}
    per_idx = 0
    for i, file_name in enumerate(sorted(os.listdir(data_path))):
        if file_name.endswith("_skeletons.txt"):
            ske_path = os.path.join(data_path, file_name)
            # print(ske_path)
            with open(ske_path, 'r') as sf:
                while sf:
                    person_id = sf.readline()
                    if person_id == '':
                        break
                    if per_dict.get(person_id) == None:
                        per_dict[person_id] = per_idx
                        per_idx += 1
                    # if len(per_dict.keys()) > 1:
                    #     print(ske_path)
                    #     print(per_dict)
                    for j in range(20):
                        sline = sf.readline().split(",")
                        x, y, z = float(sline[0]), float(sline[1]), float(sline[2])
                        data[per_dict.get(person_id), frame_num, j, :] = [x, y, z] 
            frame_num += 1
    # 取出可能性最大的一个body
    energy = np.array([get_nonzero_std(x) for x in data])
    index = energy.argsort()[::-1][0:max_body_true]
    data = data[index]
    data = data.transpose(3, 1, 2, 0)  # 转为 通道数、帧数、关节点数、人数
    return data


def rename_cls(cls):
    """
    由于7和10没有动作,所以需要对序号进行重新连续编号
    """
    if cls in [1, 2, 3, 4, 5, 6]:
        return cls
    elif cls in [8, 9]:
        return cls - 1
    elif cls in [11, 12]:
        return cls - 2
    else:
        raise ValueError()

def gendata(data_path, out_path, views, part):
    """
    data_path: 原始的数据路径
    out_path: 输出文件的路径
    views: 包含的视图
    part: 当前正在处理训练集还是测试集
    """
    sample_name = []
    sample_label = []
    for view in views:
        view_data_path = os.path.join(data_path, view)  # 获取到view的文件路径
        for filename in os.listdir(view_data_path):  # 该view下的所有文件夹进行遍历
            action_class = rename_cls(int(filename[filename.find('a') + 1:filename.find('a') + 3]))   # 获得类别ID
            
            istraining = (view[-1] in training_cameras)  # 如果参与者的ID在训练列表里,则改文件为验证集上的
            
            if part == 'train':
                issample = istraining
            elif part == 'val':
                issample = not (istraining)
            else:
                raise ValueError()

            if issample:
                sample_name.append(view + "-" + filename)  # 获得文件名
                sample_label.append(action_class - 1)  # 获得类别,这里从0开始,因此需要减1

    with open('{}/{}_label.pkl'.format(out_path, part), 'wb') as f:
        pickle.dump((sample_name, list(sample_label)), f)

    fp = np.zeros((len(sample_label), 3, max_frame, num_joint, max_body_true), dtype=np.float32)

    for i, s in enumerate(tqdm(sample_name)):
        # 读取xyz数据,即获得三维的人体关节点数据,返回的是[通道数、帧数、关节点数、人数]
        data = read_xyz(os.path.join(data_path, s.split("-")[0], s.split("-")[1]), max_body=max_body, num_joint=num_joint)
        fp[i, :, 0:data.shape[1], :, :] = data

    fp = pre_normalization(fp)
    np.save('{}/{}_data_joint.npy'.format(out_path, part), fp)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='NTU-RGB-D Data Converter.')
    parser.add_argument('--data_path', default='../data/ucla_raw/ucla_multiview_action/')
    parser.add_argument('--out_folder', default='../data/ucla/')
    views = ['view_1', 'view_2', 'view_3']
    part = ['train', 'val']
    arg = parser.parse_args()
    
    for p in part:  # 训练集和验证集
        out_path = arg.out_folder   # 获得输出路径
        if not os.path.exists(out_path):  # 如果不存在该路径,则创建该文件夹
            os.makedirs(out_path)
        print("doing=", p)
        # 获取数据
        data_path = arg.data_path
        gendata(data_path, out_path, views, p)

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值