NLF smoothing, annotated
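
This post is an annotated walkthrough of the smoothing stage of the NLF (Neural Localizer Fields) inference pipeline: the script loads per-frame body-mesh predictions, matches them to segmentation masks to form per-person tracks, temporally smooths the tracks with robust filters, fits the SMPL body model to the smoothed points, and finally estimates the ground plane.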

import argparse
import itertools
import os
import os.path as osp
import sys
from typing import Optional

import cameravision
import more_itertools
import framepump
import numpy as np
import rlemasklib
import scipy.optimize
import simplepyutils as spu
import simplepyutils.argparse as spu_argparse

# Path setup needed to import the project's own modules
current_dir = os.path.dirname(os.path.abspath(__file__))  # directory of this file
current_dir_p = os.path.dirname(current_dir)  # its parent directory
os.chdir(current_dir)  # switch the working directory to this file's location
print('current_dir', current_dir)

# Add the project-related paths to the module search path
paths = [current_dir, current_dir_p]
paths.append('/shared_disk/users/lbg/project/nlf_train/NeuralLocalizerFields')  # project root

for path in paths:
    sys.path.insert(0, path)  # prepend to sys.path
    os.environ['PYTHONPATH'] = (os.environ.get('PYTHONPATH', '') + ':' + path).strip(':')  # keep PYTHONPATH in sync

# Body-model-related imports
import smplfitter.np
import smplfitter.pt
import torch
import torch.nn.functional as F
from bodycompress import BodyDecompressor
from nlf_pipeline.util import smpl_mask

# Data-path environment variables
os.environ['INFERENCE_ROOT'] = '/shared_disk/users/lbg/data/nlf_val'  # root of the inference results
os.environ['DATA_ROOT'] = '/shared_disk/users/lbg/codes/nlf_torch/models'  # root of the model data

# Import the path configuration from the project's own module
from nlf_pipeline.util.paths import DATA_ROOT, INFERENCE_ROOT
from simplepyutils import FLAGS

# Re-apply the path settings (possibly to override the defaults picked up at import time)
os.environ['INFERENCE_ROOT'] = '/shared_disk/users/lbg/data/nlf_val'
from nlf_pipeline.util import paths
paths.DATA_ROOT = '/shared_disk/users/lbg/codes/nlf_torch/models'
os.environ['DATA_ROOT'] = '/shared_disk/users/lbg/codes/nlf_torch/models'
DATA_ROOT = '/shared_disk/users/lbg/codes/nlf_torch/models'

# Image-processing imports
import cv2
from glob import glob


def load_video_frames(video_path):
    """
    Load frames from a video file or an image directory.
    :param video_path: path to a video file or to a directory of images
    :return: list of frames and the frame rate
    """
    frames = []
    if os.path.isdir(video_path):  # image directory
        # Support PNG and JPG; sort so the frame order is deterministic
        files = sorted(glob(video_path + '/*.png') + glob(video_path + '/*.jpg'))
        for file in files:
            img = cv2.imread(file)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # convert to RGB
            frames.append(img)
        return frames, 25  # assume a default of 25 fps for image sequences
    else:  # video file
        video = cv2.VideoCapture(video_path)
        fps = int(video.get(cv2.CAP_PROP_FPS))  # frame rate of the video
        count = video.get(cv2.CAP_PROP_FRAME_COUNT)  # reported total frame count

        while True:
            ret, frame = video.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # convert color space

        # Check that the number of decoded frames matches the reported count
        if count != len(frames):
            print('video count err', count, len(frames))
        return frames, fps


def initialize():
    """
    Set up command-line argument parsing.
    """
    parser = argparse.ArgumentParser()
    # Video-related arguments
    parser.add_argument("--video-id", type=str, default='_28_0206_152203', help="video ID")
    parser.add_argument("--suffix", type=str, default='', help="suffix of the result files")
    parser.add_argument("--fov", type=float, default=55, help="camera field of view (degrees)")

    # Processing options
    parser.add_argument('--fill-gaps', action=spu_argparse.BoolAction, help="fill missing frames in the tracks")
    parser.add_argument('--skip-existing', action=spu_argparse.BoolAction, default=True, help="skip videos whose results already exist")
    parser.add_argument('--fps-factor', type=int, default=1, help="frame-rate scaling factor (for interpolation)")

    spu.initialize(parser)  # initialize the simplepyutils configuration
    torch.set_num_threads(12)  # limit the number of PyTorch threads


def main():
    initialize()  # parse command-line arguments
    body_model_name = 'smpl'  # use the SMPL body model
    fov = FLAGS.fov  # field-of-view argument

    # Build the paths of the various data files
    pred_path = f'{INFERENCE_ROOT}/preds_np/{FLAGS.video_id}{FLAGS.suffix}.xz'  # predictions
    mask_path = f'{INFERENCE_ROOT}/masks/{FLAGS.video_id}_masks.pkl'  # masks
    camera_path = f'{INFERENCE_ROOT}/cameras/{FLAGS.video_id}.pkl'  # camera parameters

    # Paths related to frame-rate scaling
    dfps_str = f'_fps{FLAGS.fps_factor}' if FLAGS.fps_factor != 1 else ''
    video_path = f'{INFERENCE_ROOT}/videos_in/{FLAGS.video_id}.mp4'  # input video

    # Path of the smoothed fitting results
    fit_path = (
        f'{INFERENCE_ROOT}/smooth_fits/{FLAGS.video_id}{FLAGS.suffix}_smoothfits{dfps_str}.pkl'
    )
    camera_path_new = f'{INFERENCE_ROOT}/cameras/{FLAGS.video_id}{dfps_str}.pkl'  # interpolated cameras
    ground_path = f'{INFERENCE_ROOT}/cameras/{FLAGS.video_id}{FLAGS.suffix}_ground.pkl'  # ground-plane parameters

    # Skip videos whose results already exist
    if FLAGS.skip_existing and osp.exists(fit_path):
        return

    # Load the body model
    body_model = smplfitter.np.get_cached_body_model(body_model_name)
    video, fps = load_video_frames(video_path)  # load the video frames
    img_shape = video[0].shape[:2]  # image size as (height, width)
    print('img_shape', img_shape)

    # Load the mask data
    masks = spu.load_pickle(mask_path)

    # Load the camera parameters, or initialize them if the file is missing
    if osp.isfile(camera_path):
        cameras_display = spu.load_pickle(camera_path)
    else:
        # Initialize the cameras with the fixed field-of-view argument
        cameras_display = [cameravision.Camera.from_fov(fov, img_shape)] * len(masks)

    n_verts = body_model.num_vertices  # number of mesh vertices
    n_joints = body_model.num_joints  # number of joints

    # Build per-person tracks by matching the predictions to the masks
    # (the total point count per person is n_verts + n_joints)
    tracks = build_tracks_via_masks(
        pred_path, cameras_display, img_shape, masks, n_verts + n_joints, n_verts
    )
    cam_points = np.stack([c.t for c in cameras_display], axis=0)  # camera positions

    # Smooth each track
    fit_tracks = [
        smooth_person_track(
            body_model_name,
            track[:, :n_verts, :3] / 1000,  # vertex coordinates (converted to meters)
            track[:, n_verts:n_verts + n_joints, :3] / 1000,  # joint coordinates (converted to meters)
            track[:, :n_verts, 3],  # vertex uncertainties
            track[:, n_verts:n_verts + n_joints, 3],  # joint uncertainties
            cam_points / 1000,  # camera positions (converted to meters)
            fps,
            n_verts,
            n_joints,
        )
        for track in spu.progressbar(tracks, desc='Fitting tracks')
    ]

    # Collect the valid fits of each frame
    valid_fits_per_frame = collect_valid_fits_per_frame(fit_tracks)
    spu.dump_pickle(valid_fits_per_frame, fit_path)  # save the fitting results

    # Fit the ground plane and the up vector
    ground_height, new_up = fit_ground_plane(
        body_model, valid_fits_per_frame, cameras_display[0].world_up
    )
    spu.dump_pickle(dict(ground_height=ground_height, world_up=new_up), ground_path)  # save the ground parameters

    # Frame-rate scaling: insert interpolated cameras between consecutive frames
    # (a midpoint scheme mirroring interp_fun in smooth_person_track; assumes fps_factor == 2)
    if FLAGS.fps_factor != 1:
        mids = [interp_cam(c1, c2, 0.5) for c1, c2 in zip(cameras_display[:-1], cameras_display[1:])]
        mids.append(cameras_display[-1])  # duplicate the last camera, as interp_fun does
        cameras_display = list(itertools.chain.from_iterable(zip(cameras_display, mids)))
        spu.dump_pickle(cameras_display, camera_path_new)


def collect_valid_fits_per_frame(fit_tracks):
    """
    Collect the valid fits of each frame (filtering out NaNs).
    :param fit_tracks: fitting results of the individual tracks
    :return: list of per-frame valid fits
    """
    fits_per_frame = []
    # Shapes of the per-frame elements
    pose_rotvecs_shape = fit_tracks[0]['pose_rotvecs'].shape[1:]
    shape_betas_shape = fit_tracks[0]['shape_betas'].shape[1:]
    trans_shape = fit_tracks[0]['trans'].shape[1:]

    for i_frame in range(fit_tracks[0]['pose_rotvecs'].shape[0]):
        pose_rotvecs = []
        shape_betas = []
        trans = []
        for track in fit_tracks:
            # Keep only finite (non-NaN) fits
            if np.all(np.isfinite(track['pose_rotvecs'][i_frame])):
                pose_rotvecs.append(track['pose_rotvecs'][i_frame])
                shape_betas.append(track['shape_betas'][i_frame])
                trans.append(track['trans'][i_frame])
        # Stack the valid fits
        fits_per_frame.append(
            dict(
                pose_rotvecs=stack(pose_rotvecs, pose_rotvecs_shape),
                shape_betas=stack(shape_betas, shape_betas_shape),
                trans=stack(trans, trans_shape),
            )
        )
    return fits_per_frame


def interp_cam(cam1, cam2, t):
    """
    Interpolate between two cameras.
    :param cam1: first camera
    :param cam2: second camera
    :param t: interpolation factor in [0, 1]
    :return: the interpolated camera
    """
    # Interpolate focal length and principal point linearly in log space
    f1 = np.array([cam1.intrinsic_matrix[0, 0], cam1.intrinsic_matrix[1, 1]])
    f2 = np.array([cam2.intrinsic_matrix[0, 0], cam2.intrinsic_matrix[1, 1]])
    c1 = cam1.intrinsic_matrix[:2, 2]
    c2 = cam2.intrinsic_matrix[:2, 2]
    f = np.exp(np.log(f1) + t * (np.log(f2) - np.log(f1)))
    c = np.exp(np.log(c1) + t * (np.log(c2) - np.log(c1)))
    intr = np.array([[f[0], 0, c[0]], [0, f[1], c[1]], [0, 0, 1]])  # intrinsic matrix

    # Interpolate distortion coefficients and extrinsics
    dist1 = cam1.get_distortion_coeffs()
    dist2 = cam2.get_distortion_coeffs()
    dist = dist1 + t * (dist2 - dist1)  # linear interpolation of the distortion parameters

    optical_center = cam1.t + t * (cam2.t - cam1.t)  # linear interpolation of the camera position
    rot = project_to_SO3(cam1.R + t * (cam2.R - cam1.R))  # project back onto SO(3) to keep the rotation orthogonal

    return cameravision.Camera(
        rot_world_to_cam=rot,
        intrinsic_matrix=intr,
        optical_center=optical_center,
        distortion_coeffs=dist,
        world_up=cam1.world_up,
    )


def project_to_SO3(A):
    """
    Project a matrix onto the special orthogonal group SO(3).
    :param A: input matrix (or batch of matrices)
    :return: the closest rotation matrix
    """
    U, _, Vh = np.linalg.svd(A)
    T = U @ Vh
    # Handle reflections (correct the result when the determinant is negative)
    has_reflection = (np.linalg.det(T) < 0)[..., np.newaxis, np.newaxis]
    T_mirror = T - 2 * U[..., -1:] @ Vh[..., -1:, :]
    return np.where(has_reflection, T_mirror, T)
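
# This is the orthogonal Procrustes solution: the rotation closest to A in
# the Frobenius norm. For example (made-up values), a blend such as
# A = 0.7 * R + 0.3 * np.eye(3) of a rotation R with the identity is not
# orthogonal in general, but T = project_to_SO3(A) satisfies T @ T.T == I
# with det(T) == +1.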


def stack(a, element_shape):
    """
    Stack a list of arrays, handling the empty-list case.
    :param a: list of arrays
    :param element_shape: shape of one element
    :return: the stacked array
    """
    if len(a) == 0:
        return np.zeros((0,) + element_shape, np.float32)
    return np.stack(a, axis=0)


# PyTorch smoothing filters below (repetitive comments omitted; these perform temporal smoothing of time series)
def conv1d_indep(a: torch.Tensor, kernel: torch.Tensor):
    return F.conv1d(a, kernel[np.newaxis, np.newaxis], padding='same')


def moving_mean(a: torch.Tensor, weights: torch.Tensor, kernel: torch.Tensor):
    finite = torch.all(
        torch.logical_and(torch.isfinite(a), torch.isfinite(weights)), dim=1, keepdim=True
    )
    a = torch.where(finite, a, torch.zeros_like(a))
    weights = torch.where(finite, weights, torch.zeros_like(weights))
    return torch.nan_to_num(conv1d_indep(weights * a, kernel) / conv1d_indep(weights, kernel))


def moving_mean_dim(x: torch.Tensor, weights: torch.Tensor, kernel: torch.Tensor, dim: int = -2):
    weights = torch.broadcast_to(weights, x.shape)
    x = x.movedim(dim, -1)
    weights = weights.movedim(dim, -1)
    mean = moving_mean(
        x.reshape(-1, 1, x.shape[-1]), weights.reshape(-1, 1, weights.shape[-1]), kernel
    ).reshape(x.shape)
    return mean.movedim(-1, dim)


@torch.jit.script
def robust_geometric_filter(
    x: torch.Tensor,
    w: Optional[torch.Tensor],
    kernel: torch.Tensor,
    dim: int = -2,
    eps: float = 1e-1,
    n_iter: int = 10,
):
    # Robust geometric filtering: an iteratively reweighted moving average
    # (a Weiszfeld-style iteration toward a windowed geometric median)
    w_ = torch.ones_like(x[..., :1]) if w is None else w.unsqueeze(-1)
    x = torch.where(torch.isfinite(x), x, torch.zeros_like(x))
    y = moving_mean_dim(x, w_, kernel, dim=dim)
    for _ in range(n_iter):
        dist = torch.norm(x - y, dim=-1, keepdim=True)
        w_modified = w_ / (dist + eps)
        y = moving_mean_dim(x, w_modified, kernel, dim=dim)
    w2 = moving_mean_dim(w_, w_modified, kernel, dim=dim)
    return y, w2.squeeze(-1)
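
# Intuition (toy values): with x = torch.zeros(30, 3) except for an outlier
# frame x[15] = 10.0, w = None and kernel = torch.tensor([1., 2., 3., 2., 1.]),
# a plain moving average would smear the outlier over its neighbors, while
# this filter keeps down-weighting it by 1 / (distance + eps), so the
# filtered track stays near zero around frame 15.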


def robust_geometric_filter_twosided(
    x,
    w,
    kernel,
    dim_t=-3,
    dim_n=None,
    eps=1e-1,
    n_iter=10,
    split_threshold=1,
    split_indices=None,
):
    # Two-sided robust filtering: split the sequence at motion discontinuities
    # and filter each chunk separately
    if dim_t < 0:
        dim_t = len(x.shape) + dim_t
    if split_indices is None:
        # Detect split points automatically: filter with a causal (left) and an
        # anti-causal (right) half-kernel and look for local maxima of their disagreement
        kernel_half_size = kernel.shape[0] // 2
        center_half = kernel[kernel_half_size : kernel_half_size + 1] / 2
        left_half_kernel = torch.cat(
            [kernel[:kernel_half_size], center_half, torch.zeros_like(kernel[kernel_half_size:])],
            dim=0,
        )
        right_half_kernel = torch.cat(
            [
                torch.zeros_like(kernel[:-kernel_half_size]),
                center_half,
                kernel[-kernel_half_size:],
            ],
            dim=0,
        )
        left_filtered = robust_geometric_filter(x, w, left_half_kernel, dim=dim_t, eps=eps, n_iter=n_iter)[0]
        right_filtered = robust_geometric_filter(x, w, right_half_kernel, dim=dim_t, eps=eps, n_iter=n_iter)[0]
        d = torch.norm(left_filtered - right_filtered, dim=-1)
        if dim_n is not None:
            d = d.min(dim=dim_n, keepdim=True).values
        # Local maxima of the disagreement are the split-point candidates
        midslice = tuple([slice(1, -1) if i == dim_t else slice(None) for i in range(len(d.shape))])
        leftslice = tuple([slice(0, -2) if i == dim_t else slice(None) for i in range(len(d.shape))])
        rightslice = tuple([slice(2, None) if i == dim_t else slice(None) for i in range(len(d.shape))])
        tiebreaker_noise = torch.randn_like(d)
        noisy_diff = d + 1e-3 * split_threshold * tiebreaker_noise
        is_local_max = torch.logical_and(
            noisy_diff[midslice] > noisy_diff[leftslice],
            noisy_diff[midslice] > noisy_diff[rightslice],
        )
        pad = torch.zeros_like(torch.index_select(is_local_max, dim_t, torch.tensor([0], dtype=torch.int64)))
        is_local_max = torch.cat([pad, is_local_max, pad], dim=dim_t)
        is_split_point = torch.logical_and(is_local_max, d > split_threshold)
        split_indices = torch.where(is_split_point)[0] + 1
    # Filter each chunk independently
    ragged_x = torch.tensor_split(x, split_indices, dim=dim_t)
    ragged_w = torch.tensor_split(w, split_indices, dim=dim_t)
    filtered = []
    filtered_weights = []
    for _x, _w in zip(ragged_x, ragged_w):
        filtered_x, filtered_w = robust_geometric_filter(_x, _w, kernel, dim=dim_t, eps=eps, n_iter=n_iter)
        filtered.append(filtered_x)
        filtered_weights.append(filtered_w)
    return torch.cat(filtered, dim=dim_t), torch.cat(filtered_weights, dim=dim_t), split_indices


def apply_nanmask(values, weights):
    # Handle NaNs (zero out non-finite values and their weights)
    if values.dim() > weights.dim():
        values, weights = apply_nanmask(values, weights.unsqueeze(-1))
        return values, weights.squeeze(-1)
    finite = torch.all(torch.logical_and(torch.isfinite(values), torch.isfinite(weights)), dim=-1, keepdim=True)
    values = torch.where(finite, values, torch.zeros_like(values))
    weights = torch.where(finite, weights, torch.zeros_like(weights))
    return values, weights


def reduce_nanmean(x, dim, keepdim=False):
    # Mean over the finite (non-NaN) entries
    is_finite = torch.isfinite(x)
    x = torch.where(is_finite, x, torch.zeros_like(x))
    return torch.sum(x, dim=dim, keepdim=keepdim) / torch.sum(is_finite, dim=dim, keepdim=keepdim)
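
# For example, reduce_nanmean(torch.tensor([1.0, float('nan'), 3.0]), dim=0)
# sums the finite entries (4.0), divides by their count (2), and returns 2.0.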


def smooth_person_track(
    body_model_name,
    vertices,
    joints,
    vertices_uncert,
    joints_uncert,
    cam_points,
    fps,
    n_verts,
    n_joints,
    n_subset=1024,
):
    """
    平滑人体轨迹(包括顶点和关节的时间滤波)
    :param body_model_name: 人体模型名称
    :param vertices: 顶点坐标
    :param joints: 关节坐标
    :param vertices_uncert: 顶点不确定性
    :param joints_uncert: 关节不确定性
    :param cam_points: 相机位置
    :param fps: 帧率
    :param n_verts: 顶点数
    :param n_joints: 关节数
    :param n_subset: 顶点子集大小(用于加速计算)
    :return: 平滑后的轨迹数据
    """
    # Convert everything to PyTorch tensors
    vertices = torch.as_tensor(vertices, dtype=torch.float32)
    joints = torch.as_tensor(joints, dtype=torch.float32)
    vertices_uncert = torch.as_tensor(vertices_uncert, dtype=torch.float32)
    joints_uncert = torch.as_tensor(joints_uncert, dtype=torch.float32)
    cam_points = torch.as_tensor(cam_points, dtype=torch.float32)

    # Compute weights as an inverse power of the uncertainties
    def get_weights(uncerts, exponent):
        w = uncerts**-exponent
        return w / reduce_nanmean(w, dim=(-2, -1), keepdim=True)  # normalize the weights

    joint_weights = get_weights(joints_uncert, 1.5)
    vertex_weights = get_weights(vertices_uncert, 1.5)

    # Apply the NaN mask (zero out non-finite values and their weights)
    vertices, vertex_weights = apply_nanmask(vertices, vertex_weights)
    joints, joint_weights = apply_nanmask(joints, joint_weights)

    # Select a vertex subset (to speed up the fitting)
    if n_subset < n_verts:
        vertex_subset_np = np.load(f'{DATA_ROOT}/body_models/smpl/vertex_subset_{n_subset}.npz')['i_verts']
    else:
        vertex_subset_np = np.arange(n_verts)
    vertex_subset_pt = torch.from_numpy(vertex_subset_np)

    # Preliminary fit to obtain a scale correction
    fits_prelim = smplfitter.pt.get_cached_fit_fn(
        body_model_name=body_model_name,
        requested_keys=('vertices', 'joints', 'pose_rotvecs'),
        beta_regularizer=10 * ((n_subset + n_joints) / (n_verts + n_joints)),  # regularization strength, rescaled for the vertex subset size
        beta_regularizer2=0.2 * ((n_subset + n_joints) / (n_verts + n_joints)),
        scale_regularizer=1 * ((n_subset + n_joints) / (n_verts + n_joints)),
        num_betas=10,
        vertex_subset=tuple(vertex_subset_np),
        share_beta=True,
        final_adjust_rots=False,
        scale_target=True,
        device='cpu',
    )(
        torch.index_select(vertices, dim=1, index=vertex_subset_pt),
        joints,
        torch.index_select(vertex_weights, dim=1, index=vertex_subset_pt),
        joint_weights,
    )

    # Scale correction, applied relative to the camera positions
    scale = fits_prelim['scale_corr']
    scale /= torch.nanmedian(scale, dim=0).values  # normalize by the median scale

    def scale_from_cam(scene_points, cam_points, scale):
        cam_points = cam_points[:, np.newaxis]
        return (scene_points - cam_points) * scale[:, np.newaxis] + cam_points  # scale along the rays from the camera

    vertices = scale_from_cam(vertices, cam_points, scale[:, np.newaxis])
    joints = scale_from_cam(joints, cam_points, scale[:, np.newaxis])

    # Temporal smoothing (robust filtering of the root-joint trajectory)
    if fps < 40:
        kernel_large = torch.tensor([0.01, 0.05, 0.1, 1, 2, 6, 2, 1, 0.1, 0.05, 0.01], dtype=torch.float32)
    else:
        kernel_large = torch.tensor([0.01, 0.02, 0.05, 0.1, 0.2, 1, 1.3, 2, 3.3, 6, 3.3, 2, 1.3, 1, 0.2, 0.1, 0.05, 0.02, 0.01], dtype=torch.float32)

    filtered_root, _, split_indices = robust_geometric_filter_twosided(
        joints[:, 0], joint_weights[:, 0], kernel_large, dim_t=-2, eps=5e-2, split_threshold=1
    )

    # Compute per-frame scale factors from the root-joint distances
    root_dist_sq = torch.sum(torch.square(joints[:, 0] - cam_points), dim=-1, keepdim=True)
    filtered_root_distance = torch.sum((filtered_root - cam_points) * (joints[:, 0] - cam_points), dim=-1, keepdim=True)
    scales = filtered_root_distance / root_dist_sq
    vertices = scale_from_cam(vertices, cam_points, scales)
    joints = scale_from_cam(joints, cam_points, scales)

    # Temporal smoothing of the vertices and joints
    if fps < 40:
        kernel_small = torch.tensor([1, 3, 12, 3, 1], dtype=torch.float32)
    else:
        kernel_small = torch.tensor([1, 1.5, 3, 6, 12, 6, 3, 1.5, 1], dtype=torch.float32)
    points = torch.cat([vertices, joints], dim=-2)
    weights = torch.cat([vertex_weights, joint_weights], dim=-1)
    points, new_weights, _ = robust_geometric_filter_twosided(
        points, weights, kernel_small, dim_t=-3, dim_n=-2, eps=5e-2, split_threshold=1, split_indices=split_indices
    )
    vertices, joints = torch.split(points, [n_verts, n_joints], dim=-2)
    vertex_weights, joint_weights = torch.split(new_weights, [n_verts, n_joints], dim=-1)

    # Frame-rate scaling: interpolate to generate intermediate frames
    # (this midpoint scheme effectively assumes fps_factor == 2)
    if FLAGS.fps_factor != 1:
        def interp_fun(vals):
            mids = 0.5 * (vals[:-1] + vals[1:])
            mids = torch.cat([mids, vals[-1:]], dim=0)
            return torch.reshape(torch.stack([vals, mids], dim=1), [-1, *vals.shape[1:]])  # interleave originals with midpoints

        vertices = interp_fun(vertices)
        joints = interp_fun(joints)
        vertex_weights = interp_fun(vertex_weights)
        joint_weights = interp_fun(joint_weights)
        split_indices = split_indices * FLAGS.fps_factor
        cam_points = interp_fun(cam_points)

    # Final fit (no scale estimation, with final rotation adjustment)
    fits = smplfitter.pt.get_cached_fit_fn(
        body_model_name=body_model_name,
        requested_keys=('vertices', 'joints', 'pose_rotvecs'),
        beta_regularizer=10 * ((n_subset + n_joints) / (n_verts + n_joints)),
        beta_regularizer2=0,
        num_betas=10,
        vertex_subset=tuple(vertex_subset_np),
        share_beta=True,
        final_adjust_rots=True,
        device='cpu',
    )(
        torch.index_select(vertices, dim=1, index=vertex_subset_pt),
        joints,
        torch.index_select(vertex_weights, dim=1, index=vertex_subset_pt),
        joint_weights,
    )

    # Run the fitted body model forward
    body_model = smplfitter.pt.get_cached_body_model(body_model_name)
    fit_res = body_model.forward(fits['pose_rotvecs'], fits['shape_betas'], fits['trans'])
    vertices = fit_res['vertices']
    joints = fit_res['joints']

    # Smooth the root-joint translation
    if fps * FLAGS.fps_factor < 40:
        kernel = torch.tensor([1, 2, 3, 2, 1], dtype=torch.float32)
    else:
        kernel = torch.tensor([1, 1.5, 2, 2.5, 3, 2.5, 2, 1.5, 1], dtype=torch.float32)
    filtered_root = robust_geometric_filter_twosided(
        joints[:, 0], joint_weights[:, 0], kernel, dim_t=-2, eps=5e-2, split_indices=split_indices, split_threshold=1
    )[0]

    # Compute the translation offset
    root_dist_sq = torch.sum(torch.square(joints[:, 0] - cam_points), dim=-1, keepdim=True)
    filtered_root_distance = torch.sum((filtered_root - cam_points) * (joints[:, 0] - cam_points), dim=-1, keepdim=True)
    scales = (filtered_root_distance / root_dist_sq)[..., np.newaxis]
    offset = (scales - 1) * joints[:, :1]

    # Apply the offset
    vertices = vertices + offset
    joints = joints + offset
    fits['trans'] = fits['trans'] + torch.squeeze(offset, dim=-2)

    # Mark invalid frames (set them to all-NaN)
    is_invalid = torch.all(joint_weights == 0, dim=1, keepdim=True)[..., np.newaxis]
    vertices = torch.where(is_invalid, torch.nan, vertices)
    joints = torch.where(is_invalid, torch.nan, joints)
    is_invalid = torch.squeeze(is_invalid, dim=-1)
    pose_rotvecs = torch.where(is_invalid, torch.nan, fits['pose_rotvecs'])
    shape_betas = torch.where(is_invalid, torch.nan, fits['shape_betas'])
    trans = torch.where(is_invalid, torch.nan, fits['trans'])

    return dict(
        vertices=vertices.numpy(),
        joints=joints.numpy(),
        pose_rotvecs=pose_rotvecs.numpy(),
        shape_betas=shape_betas.numpy(),
        trans=trans.numpy(),
    )


def fill_nan_with_prev_nonnan(a, axis):
    """
    用前一有效帧填充NaN(沿指定轴)
    :param a: 输入数组
    :param axis: 填充轴
    """
    prev = None
    for item in iterdim(a, axis):
        isnan = np.isnan(item)
        if prev is not None:
            item[isnan] = prev[isnan]  # fill with the previous frame's non-NaN values
        prev = item.copy()


def iterdim(a, axis=0):
    """
    沿指定轴迭代数组元素
    :param a: 输入数组
    :param axis: 迭代轴
    :yield: 轴上的每个元素
    """
    a = np.asarray(a)
    leading_indices = (slice(None),) * axis
    for i in range(a.shape[axis]):
        yield a[leading_indices + (i,)]
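
# A tiny check of the forward fill (toy values):
#   a = np.array([[1.0, np.nan, np.nan, 4.0]])
#   fill_nan_with_prev_nonnan(a, axis=1)
#   # a is now [[1., 1., 1., 4.]]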


def build_tracks_via_masks(pred_path, cameras_display, video_imshape, masks, n_points, n_verts, n_coords=4, iou_threshold=0):
    """
    通过掩码匹配构建人体轨迹
    :param pred_path: 预测结果路径
    :param cameras_display: 相机参数列表
    :param video_imshape: 视频图像尺寸
    :param masks: 掩码列表
    :param n_points: 总点数(顶点+关节)
    :param n_verts: 顶点数
    :param n_coords: 坐标维度(含不确定性)
    :param iou_threshold: IOU匹配阈值
    :return: 轨迹数组
    """
    preds_per_frame = []
    pred_reader = BodyDecompressor(pred_path)  # streaming reader for the compressed predictions

    for i_frame, (d, camera_display) in enumerate(
        spu.progressbar(
            zip(pred_reader, cameras_display),
            total=len(masks),
            desc='Matching meshes to masks',
        )
    ):
        # Concatenate coordinates and uncertainties
        points = np.concatenate([d['vertices'], d['joints']], axis=1)
        uncerts = np.concatenate([d['vertex_uncertainties'], d['joint_uncertainties']], axis=1)
        points_and_uncerts = np.concatenate([points, uncerts[..., np.newaxis]], axis=-1)

        # Associate the predictions with the masks (IoU-based matching)
        ordered_preds = associate_predictions_to_masks_mesh(
            poses3d_pred=points_and_uncerts,
            frame_shape=video_imshape,
            masks=masks[i_frame],
            camera=camera_display,
            n_points=n_points,
            n_verts=n_verts,
            n_coords=n_coords,
            iou_threshold=iou_threshold,
        )
        preds_per_frame.append(ordered_preds)

    tracks = np.stack(preds_per_frame, axis=1)  # stack into per-person track arrays

    # Optionally fill gaps with the previous valid frame
    if FLAGS.fill_gaps:
        fill_nan_with_prev_nonnan(tracks, axis=1)
    return tracks


def associate_predictions_to_masks_mesh(
    poses3d_pred, frame_shape, masks, camera, n_points, n_verts, n_coords=4, iou_threshold=0
):
    """
    将预测网格与掩码关联(基于IOU匹配)
    :param poses3d_pred: 预测的3D姿态(含坐标和不确定性)
    :param frame_shape: 图像尺寸
    :param masks: 掩码列表
    :param camera: 相机参数
    :param n_points: 总点数
    :param n_verts: 顶点数
    :param n_coords: 坐标维度
    :param iou_threshold: IOU阈值
    :return: 关联后的结果数组
    """
    n_true_poses = len(masks)
    result = np.full((n_true_poses, n_points, n_coords), np.nan, dtype=np.float32)
    if n_true_poses == 0:
        return result

    # Convert the masks to RLE format
    mask_rles = [rlemasklib.RLEMask.from_dict(m) for m in masks]
    mask_shape = mask_rles[0].shape

    # Rescale the camera to match the mask resolution
    camera_rescaled = camera.scale_output(
        [mask_shape[1] / frame_shape[1], mask_shape[0] / frame_shape[0]], inplace=False
    )

    # Render RLE masks of the predicted meshes
    pose_rles = smpl_mask.render_rle(
        poses3d_pred[:, :n_verts, :3], camera_rescaled, mask_shape, 1024
    )

    # Compute the IoU matrix and match with the Hungarian algorithm
    iou_matrix = rlemasklib.RLEMask.iou_matrix(mask_rles, pose_rles)
    true_indices, pred_indices = scipy.optimize.linear_sum_assignment(-iou_matrix)  # maximize the total IoU

    for ti, pi in zip(true_indices, pred_indices):
        if iou_matrix[ti, pi] >= iou_threshold:
            result[ti] = poses3d_pred[pi]  # assign the matched prediction
    return result
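
# A toy example of the matching step (made-up IoU values): with
#   iou_matrix = np.array([[0.8, 0.1],
#                          [0.2, 0.6]])
# scipy.optimize.linear_sum_assignment(-iou_matrix) returns
# (array([0, 1]), array([0, 1])), i.e. mask 0 is assigned prediction 0 and
# mask 1 prediction 1, maximizing the total IoU (0.8 + 0.6).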


def find_up_vector(points, almost_up=(0, -1, 0), thresh_degrees=60):
    """
    估计向上向量(基于平面拟合)
    :param points: 3D点云(用于拟合平面)
    :param almost_up: 先验向上向量
    :param thresh_degrees: 角度阈值(超过则使用先验)
    :return: 估计的向上向量
    """
    almost_up = np.asarray(almost_up, np.float32)
    import pyransac3d

    # RANSAC拟合平面
    plane1 = pyransac3d.Plane()
    _, best_inliers = plane1.fit(points, thresh=25, maxIteration=5000)
    if len(best_inliers) < 3:
        raise ValueError('Could not fit a plane to the points, too few inliers')

    # 获取平面法向量作为向上向量
    world_up = np.asarray(fit_plane(points[best_inliers]), np.float32)
    if np.dot(world_up, almost_up) < 0:
        world_up *= -1  # 确保与先验方向一致

    # 检查角度是否在阈值内
    angle = np.rad2deg(np.arccos(np.dot(world_up, almost_up)))
    if angle > thresh_degrees:
        world_up = almost_up  # 超过阈值则使用先验

    return np.array(world_up, np.float32)


def fit_plane(points):
    """
    最小二乘拟合平面(返回法向量)
    :param points: 3D点云
    :return: 平面法向量
    """
    points = np.asarray(points, np.float32)
    x = points - np.mean(points, axis=0, keepdims=True)  # 去均值
    u, s, vt = np.linalg.svd(x, full_matrices=False)  # SVD分解
    return vt[-1]  # 最小奇异值对应的右奇异向量为法向量
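

# NOTE: fit_ground_plane is called from main() but is not defined anywhere in
# this listing. The sketch below is a hypothetical reconstruction: it runs the
# body model forward on each frame's fits (assuming the numpy body model is
# callable like the torch one), estimates the up vector from the lowest body
# vertices via find_up_vector, and reads off the ground height as a robust low
# quantile along that direction. The per-person point selection, the quantile,
# and the units (find_up_vector's thresh=25 hints at millimeters, while the
# fits here are in meters) are guesses.
def fit_ground_plane(body_model, fits_per_frame, almost_up):
    almost_up = np.asarray(almost_up, np.float32)
    low_points = []
    for fits in fits_per_frame:
        if len(fits['pose_rotvecs']) == 0:
            continue  # no valid person in this frame
        res = body_model(fits['pose_rotvecs'], fits['shape_betas'], fits['trans'])
        for verts in res['vertices']:  # one (n_verts, 3) array per person
            h = verts @ almost_up  # height of each vertex along the prior up direction
            low_points.append(verts[np.argsort(h)[:8]])  # keep the lowest vertices
    low_points = np.concatenate(low_points, axis=0)
    world_up = find_up_vector(low_points, almost_up=almost_up)
    ground_height = float(np.quantile(low_points @ world_up, 0.1))
    return ground_height, world_up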


if __name__ == '__main__':
    main()
