nerf_factory源码解读(一)

文章介绍了nerf_factory框架的数据加载过程,特别是针对blender数据集,包括加载json配置文件、获取图像和相机参数。同时,详细解析了load_blender_data函数,以及相机姿态转换函数pose_spherical,涉及相机坐标系的旋转和平移操作。
摘要由CSDN通过智能技术生成
nerf_factory源码解读(一)

此项工作是KAIST的研究者发布的一个结构化的框架,继承了多个nerf模型的通用框架,有极强的复用性,因此我个人对这个代码进行了一些解读和理解。

数据集加载

nerf_factory框架提供了多个数据的加载接口,以适应多种nerf数据集的训练。如下图所示
在这里插入图片描述
首先是加载数据集部分,本文以加载blender数据集为例子

首先观察blender/chair的文件结构:

在这里插入图片描述

其中,testtrainval中存储的是图片,其余三个json文件存储的是相关的配置文件,具体内容如下:

{
    "camera_angle_x": 0.6911112070083618,
    "frames": [
        {
            "file_path": "./train/r_0",
            "rotation": 0.012566370614359171,
            "transform_matrix": [
                [
                    -0.9250140190124512,
                    0.2748899757862091,
                    -0.2622683644294739,
                    -1.0572376251220703
                ],
                [
                    -0.3799331784248352,
                    -0.6692678928375244,
                    0.6385383605957031,
                    2.5740303993225098
                ],
                [
                    0.0,
                    0.6903012990951538,
                    0.7235219478607178,
                    2.9166102409362793
                ],
                [
                    0.0,
                    0.0,
                    0.0,
                    1.0
                ]
            ]
        },

通过load_blender_data函数读取json文件中相关配置,并且得到了所有图像,pose以及渲染pose等一系列数据和参数,以下是对该函数的解读:

def load_blender_data(
    datadir: str,  # 数据存放文件夹
    scene_name: str, # 具体的场景
    train_skip: int,
    val_skip: int,
    test_skip: int,
    cam_scale_factor: float, # 缩放系数
    white_bkgd: bool,
):
    basedir = os.path.join(datadir, scene_name) 
    cam_trans = np.diag(np.array([1, -1, -1, 1], dtype=np.float32))  # 创建4x4的对角矩阵
    splits = ["train", "val", "test"]
    metas = {}
    for s in splits:
        with open(os.path.join(basedir, "transforms_{}.json".format(s)), "r") as fp:
            metas[s] = json.load(fp)  # 分别加载三个json文件,并保存在字典里

    images = []
    extrinsics = []
    counts = [0]

    for s in splits:
        meta = metas[s]
        imgs = []
        poses = []

        if s == "train":
            skip = train_skip
        elif s == "val":
            skip = val_skip
        elif s == "test":
            skip = test_skip

        for frame in meta["frames"][::skip]:  # 以指定步长读取json里的frames
            fname = os.path.join(basedir, frame["file_path"] + ".png")
            imgs.append(imageio.imread(fname))  # 将所有图片存进imgs里
            poses.append(np.array(frame["transform_matrix"]))  # 将外参矩阵放进poses里
        imgs = (np.array(imgs) / 255.0).astype(np.float32)  # keep all 4 channels (RGBA)[N,H,W,4]
        poses = np.array(poses).astype(np.float32)  # [N,4,4]
        counts.append(counts[-1] + imgs.shape[0])  # 用来标定train val test之间的界限
        images.append(imgs)  # 把所有图片都加进images中
        extrinsics.append(poses)  # 把外参矩阵也加进extrinsics中

    i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)]

    images = np.concatenate(images, 0)  # 把所有图片聚合成一个数组[N,H,W,4]

    extrinsics = np.concatenate(extrinsics, 0)  # 把所有外参矩阵聚合成一个数组[N,4,4]

    extrinsics[:, :3, 3] *= cam_scale_factor  # 尺度缩放因子
    extrinsics = extrinsics @ cam_trans  # Y轴 Z轴进行180°转向,从colmap相机坐标系转到nerf坐标系

    h, w = imgs[0].shape[:2]  
    num_frame = len(extrinsics)
    i_split += [np.arange(num_frame)]

    camera_angle_x = float(meta["camera_angle_x"])  # 水平视角
    focal = 0.5 * w / np.tan(0.5 * camera_angle_x)  # 计算焦距
    intrinsics = np.array(
        [
            [[focal, 0.0, 0.5 * w], [0.0, focal, 0.5 * h], [0.0, 0.0, 1.0]]
            for _ in range(num_frame)
        ]
    )  # 所有图像的内参矩阵K
    image_sizes = np.array([[h, w] for _ in range(num_frame)])

    render_poses = torch.stack(
        [
            pose_spherical(angle, -30.0, 4.0) @ cam_trans  # opencv2nerf
            for angle in np.linspace(-180, 180, 40 + 1)[:-1]
        ],
        0,
    )  # 生成用于渲染的c2w
    render_poses[:, :3, 3] *= cam_scale_factor
    near = 2.0
    far = 6.0

    if white_bkgd:
        images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:])
    else:
        images = images[..., :3]

    return (
        images,
        intrinsics,
        extrinsics,
        image_sizes,
        near,
        far,
        (-1, -1),
        i_split,
        render_poses,
    )

最后返回图像矩阵,外参,内参以及渲染位姿等参数。

其中要特别对render_pose单独说明

def pose_spherical(theta, phi, radius):
    c2w = trans_t(radius)  # 相机中心进行平移,这里是沿z轴移动(我的理解是这里是pytorch3D相机坐标系)
    c2w = rot_phi(phi / 180.0 * np.pi) @ c2w  # 绕x轴旋转
    c2w = rot_theta(theta / 180.0 * np.pi) @ c2w  # 绕y轴旋转
    c2w = (
        torch.tensor([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]).float()
        @ c2w  # 相机坐标系和世界坐标系对齐,即X轴反向,y和z互换(这里我自己不太懂,为啥要进行这一步,对齐的是哪个世界坐标系?)
    )
    return c2w

具体不同相机坐标系定义见下图
在这里插入图片描述

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值