【pytorch】Animatable 3D Gaussian+源码解读(二)

接上节《【pytorch】Animatable 3D Gaussian+源码解读(一)》:在完成高斯球的定义之后,本节开始关注 NeRFModel。

class NeRFModel(pl.LightningModule):
    """LightningModule that owns the deformable Gaussian model and its evaluator."""

    def __init__(self, opt):
        """Build the model, training settings and metrics from a Hydra config.

        Args:
            opt: composed config; this constructor reads ``deformer``,
                ``training_args``, ``max_sh_degree`` and ``lambda_dssim``.
        """
        super().__init__()
        # Store every __init__ argument as a hyperparameter; later accessible as
        # self.hparams.<name> and also written out with the hyperparameter table.
        self.save_hyperparameters()
        # Instantiate the class named by opt.deformer._target_.
        self.model = hydra.utils.instantiate(opt.deformer)
        self.training_args = opt.training_args
        self.sh_degree = opt.max_sh_degree
        self.lambda_dssim = opt.lambda_dssim
        self.evaluator = Evaluator()
        # Create ./val and ./test in the current working directory if missing.
        # exist_ok=True avoids the race between an exists() check and makedirs().
        for out_dir in ("val", "test"):
            os.makedirs(out_dir, exist_ok=True)

首先保存超参数、实例化 deformer(即上节分析的 gala_model),并读取一些训练超参数;
然后创建衡量指标 LPIPS、PSNR、SSIM,
它们都封装在 Evaluator 类内;最后在当前工作目录下创建 val 和 test 两个文件夹。

class Evaluator(nn.Module):
    """Image-quality metrics (PSNR, SSIM, LPIPS) for comparing renders to ground truth.

    adapted from https://github.com/JanaldoChen/Anim-NeRF/blob/main/models/evaluator.py
    """

    def __init__(self):
        super().__init__()
        # Registration order (lpips, psnr, ssim) is kept so state_dict keys match.
        self.lpips = LearnedPerceptualImagePatchSimilarity(net_type="alex")
        # PSNR and SSIM are both configured with a data range of 1.
        self.psnr = PeakSignalNoiseRatio(data_range=1)
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1)

    @custom_fwd(cast_inputs=torch.float32)
    def forward(self, rgb, rgb_gt):
        """Score ``rgb`` against ``rgb_gt``.

        Inside an autocast region the inputs are cast to float32 first.

        Returns:
            dict with keys "psnr", "ssim" and "lpips", computed in that order.
        """
        scores = {}
        scores["psnr"] = self.psnr(rgb, rgb_gt)
        scores["ssim"] = self.ssim(rgb, rgb_gt)
        scores["lpips"] = self.lpips(rgb, rgb_gt)
        return scores

实例化:

# Build the LightningModule from the composed Hydra config.
model = NeRFModel(opt)

galaBasketball.py

# Instantiate the datamodule class named by opt.dataset._target_.
datamodule = hydra.utils.instantiate(opt.dataset)

emm要不先回去找一下opt配置文件:

train.py中默认使用gala:

# Hydra entry point: compose the config named "gala" from ./confs,
# using Hydra 1.1 default behaviors (version_base="1.1").
@hydra.main(config_path="./confs", config_name="gala", version_base="1.1")

找到confs/gala.yaml,指定dataset:gala/idle:

# Hydra defaults list: this file (`_self_`) is composed first, then the
# `dataset` group option gala/idle and the `deformer` group option gala_model,
# so values from those groups override same-named keys defined here.
defaults:
  - _self_
  - dataset: gala/idle
  - deformer: gala_model

找到gala/idle.yaml,由此收获接下来所需要的一系列超参数:

# Consumed by hydra.utils.instantiate(opt.dataset): `_target_` names the class and
# the remaining top-level keys become constructor arguments (`name` is not a
# named parameter of GalaBasketballDataModule and is absorbed by **kwargs).
_target_: animatableGaussian.dataset.galaBasketball.GalaBasketballDataModule
name: idle
num_workers: ${num_workers}  # interpolated from another config key (presumably the root config)
num_players: 1
opt:
  dataroot: ../../data/Gala/idle
  max_freq: ${max_freq}
  with_mask: False  # when True, <frame>_mask.png files are loaded as well
  train:
    # provides parameters needed to split the train set
    # A comma-separated list loads as a string; the dataset splits it on ','.
    camera_ids: 0,1,2,3,4,5
    start: 0
    end: 100  # exclusive: frames come from range(start, end, skip) -> 50 per camera
    skip: 2
    
  val:
    # provides parameters needed to split the val set
    # A single number loads as an int and is treated as one camera.
    camera_ids: 0
    start: 0
    # NOTE(review): start == end yields an empty frame range, so no validation
    # frames are loaded — presumably deliberate to skip validation; confirm.
    end: 0
    skip: 2
    
  test:
    # provides parameters needed to split the test set
    camera_ids: 6  # camera not among the training cameras 0-5
    start: 0
    end: 100
    skip: 2

那么,

# Construct the class given by opt.dataset._target_ (GalaBasketballDataModule here).
datamodule = hydra.utils.instantiate(opt.dataset)

即实例化

# Class that hydra.utils.instantiate(opt.dataset) constructs:
_target_: animatableGaussian.dataset.galaBasketball.GalaBasketballDataModule
class GalaBasketballDataModule(pl.LightningDataModule):
    """LightningDataModule serving the Gala basketball dataset splits.

    With ``train=True`` the ``train`` and ``val`` splits are loaded, otherwise
    only ``test``. Each loaded split is stored as attribute ``<split>set``.
    """

    def __init__(self, num_workers, num_players, opt, train=True, **kwargs):
        """
        Args:
            num_workers: DataLoader worker processes per loader.
            num_players: forwarded to GalaBasketballDataset.
            opt: dataset config; reads ``dataroot``, ``with_mask``, ``max_freq``
                and the per-split sub-config named after each split.
            train: load train/val when true, test otherwise.
            **kwargs: extra config keys (e.g. ``name``), accepted and ignored.
        """
        super().__init__()
        splits = ["train", "val"] if train else ["test"]
        for split in splits:
            print(f"loading {split}set...")
            dataset = GalaBasketballDataset(
                opt.dataroot, num_players, opt.with_mask, opt.max_freq, split, opt.get(split))
            setattr(self, f"{split}set", dataset)
        self.num_workers = num_workers

    def _make_loader(self, dataset, shuffle):
        """Build the batch-size-1 DataLoader shared by every split."""
        return DataLoader(dataset,
                          shuffle=shuffle,
                          pin_memory=True,
                          batch_size=1,
                          # DataLoader raises ValueError for persistent_workers=True
                          # when num_workers == 0, so only enable it with workers.
                          persistent_workers=self.num_workers > 0,
                          num_workers=self.num_workers,
                          collate_fn=my_collate_fn)

    def train_dataloader(self):
        """Shuffled loader over the train split, or the base-class fallback."""
        if hasattr(self, "trainset"):
            return self._make_loader(self.trainset, shuffle=True)
        return super().train_dataloader()

    def val_dataloader(self):
        """Ordered loader over the val split, or the base-class fallback."""
        if hasattr(self, "valset"):
            return self._make_loader(self.valset, shuffle=False)
        return super().val_dataloader()

    def test_dataloader(self):
        """Ordered loader over the test split, or the base-class fallback."""
        if hasattr(self, "testset"):
            return self._make_loader(self.testset, shuffle=False)
        return super().test_dataloader()

三个 dataloader 除了 shuffle(只有训练集打乱)之外写法完全相同,没什么特别的,所以还是把注意力放在 dataset 的处理上吧。

dataset

首先就是图像、相机、pose的一些数据处理:

class GalaBasketballDataset(torch.utils.data.Dataset):
    """Multi-view images, cameras and per-frame poses for one split of a Gala sequence."""

    def __init__(self, dataroot, num_players, with_mask, max_freq, split, opt):
        """
        Args:
            dataroot: sequence root directory.
            num_players: players per pose file, forwarded to load_pose.
            with_mask: also load mask images, forwarded to read_imgs.
            max_freq: stored on the instance; not used in this constructor.
            split: split name ("train" / "val" / "test").
            opt: split config providing ``camera_ids``, ``start``, ``end``
                (exclusive) and ``skip``.
        """
        self.split = split
        self.max_freq = max_freq
        self.opt = opt
        # Count of frames actually loaded: read_imgs and load_pose iterate
        # range(start, end, skip), which includes a partial final step, so the
        # count must be len(range(...)), not the float (end - start) / skip.
        self.num_frames = len(range(opt.start, opt.end, opt.skip))

        if isinstance(opt.camera_ids, str):
            # e.g. "0,1,2,3,4,5" -> [0, 1, 2, 3, 4, 5]
            self.camera_ids = [int(cid) for cid in opt.camera_ids.strip().split(',')]
            self.num_cameras = len(self.camera_ids)
        else:
            # A single camera id (YAML loads a bare number as an int).
            self.num_cameras = 1
            self.camera_ids = [opt.camera_ids]

        self.camera_params = getCamPara(dataroot, self.camera_ids, opt)
        # Image size is taken from the first camera.
        self.image_width = self.camera_params[0].image_width
        self.image_height = self.camera_params[0].image_height

        self.imgs, self.masks = read_imgs(
            dataroot, with_mask, self.camera_ids, opt, (self.image_width, self.image_height))

        self.poses = load_pose(dataroot, num_players, opt, split)

代码清晰易懂,展开看一下相机、图像、姿态的处理细节:

getCamPara

getCamPara 函数读取的正是(一)中介绍的相机参数文本文件,文件中每一行的含义都可以一一对应上,

比如0.txt:

25,512,512 # property:fov(视场角) width height
6,4,0 # position
0.1530459,-0.6903456,0.1530459,0.6903456 # rotation
1,1,1 # scaling

根据 opt,相机 0–5 共 6 个视角用于训练,相机 6 这 1 个视角用于测试(验证集配置为相机 0)。

def getCamPara(dataroot, camera_ids, opt):
    # load all camera and corresponding bg image
    # Each <dataroot>/camera/<id>.txt holds four comma-separated lines:
    #   line 0: fov, image width, image height
    #   line 1: camera position
    #   line 2: camera rotation (4 values; presumably a quaternion — confirm component order)
    #   line 3: camera scaling
    # NOTE(review): `opt` is not read anywhere in the code shown for this function.

    camera_params = []
    for i in range(len(camera_ids)): # handle each requested camera view in turn
        file_path = os.path.join(dataroot, "camera", str(camera_ids[i])+".txt")
        with open(file_path, 'r') as file:
            lines = file.readlines()
            line = lines[0].strip().split(',')
            property = list(map(float, line))  # NOTE(review): shadows the builtin `property`
            line = lines[1].strip().split(',')
            position = list(map(float, line))
            line = lines[2].strip().split(',')
            rotation = list(map(float, line))
            line = lines[3].strip().split(',')
            scaling = list(map(float, line))

            fov = property[0]
            height = int(property[2])
            width = int(property[1])

            # aspect ratio (width / height) passed to the projection
            wh_ratio = property[1]/property[2]
            
            # projection matrix; 0.01 and 1000 are presumably the near/far planes
            projmatrix = get_proj_mat(0.01, 1000, fov, wh_ratio)
            # view matrix: inverse of the camera transform built from position/rotation/scaling
            viewmatrix = matrix_TRS(position, rotation, scaling).inverse()
            
            # works well with unity coordinate system.
            # NOTE(review): the index list [0, 0] names row 0 twice, so this
            # negates row 0 only (once). If two axes were meant to be flipped
            # (e.g. [0, 2]), this is a typo — verify against a rendered frame.
            viewmatrix[[0, 0], :] *= -1

调整一下相机的参数格式:

            # strictly follow the data format defined in Camera
            camera_param = Camera()
            camera_param.image_height = height
            camera_param.image_width = width
            # Reciprocals of the projection's x/y scale terms. For a standard
            # perspective matrix these equal tan(fov_x / 2) and tan(fov_y / 2);
            # presumably get_proj_mat follows that convention — TODO confirm.
            camera_param.tanfovx = 1./projmatrix[0, 0]
            camera_param.tanfovy = 1./projmatrix[1, 1]

读取(一)中存疑的背景图:

            # Background image for this camera: <dataroot>/bg/<id>.png
            bg_path = os.path.join(dataroot, "bg", str(camera_ids[i])+".png")
            # Follow the following code exactly to read and process images
            # NOTE(review): the opened PIL image is never explicitly closed;
            # `with Image.open(bg_path) as img:` would release the file handle
            # deterministically.
            bg = Image.open(bg_path)
            # presumably resized to [width, height] and converted to a tensor — confirm PILtoTorch
            bg = PILtoTorch(bg, [width, height])

背景不在重建范围内,但渲染的时候要包括背景,所以直接用作camera_param.bg。

            # replace with the real bg
            camera_param.bg = bg
            camera_param.scale_modifier = 1.0
            # Matrices are stored transposed, presumably the layout the
            # rasterizer expects — confirm against the Camera consumer.
            camera_param.viewmatrix = viewmatrix.T
            # Combined projection @ view transform, also stored transposed.
            camera_param.projmatrix = (projmatrix@viewmatrix).T
            # Assuming a 4x4 affine view matrix, row 3 of inverse(view.T) holds the
            # translation of the camera-to-world transform, i.e. the camera position.
            camera_param.campos = torch.inverse(camera_param.viewmatrix)[3, :3]
            camera_params.append(camera_param)

read_imgs

从 opt 设置中可知:在全部 300 帧中,训练时使用 range(0, 100, 2),即第 0、2、…、98 帧(end 不包含),每个视角 50 帧。
read_imgs 把 6 个视角的图片按 [相机][帧] 的嵌套列表放在 imgs 中,mask 同理。

def read_imgs(dataroot, with_mask, camera_ids, opt, resolution):
    """Load frames (and optionally masks) for every camera.

    Args:
        dataroot: sequence root; frames are read from ``<dataroot>/<camera_id>/<tttt>.png``.
        with_mask: also read ``<tttt>_mask.png`` from the same directory.
        camera_ids: cameras to load, in output order.
        opt: split config providing ``start``, ``end`` (exclusive) and ``skip``.
        resolution: target size forwarded to ``PILtoTorch``.

    Returns:
        ``(imgs, masks)``: nested lists indexed ``[camera][frame]``; each entry is
        the converted tensor sliced to ``[:3]`` on its first axis (presumably the
        channel axis). ``masks`` stays empty when ``with_mask`` is false.
    """
    imgs = []
    masks = []
    for view_id in camera_ids:
        view_imgs = []
        view_masks = []
        print(f"loading images from camera {view_id}")
        view_dir = os.path.join(dataroot, str(view_id))
        for t in tqdm(range(opt.start, opt.end, opt.skip)):
            stem = str(t).zfill(4)
            # Follow the following code exactly to read and process images.
            # The context manager closes each file once it has been converted.
            with Image.open(os.path.join(view_dir, stem + ".png")) as img:
                view_imgs.append(PILtoTorch(img, resolution)[:3, ...])
            if with_mask:
                with Image.open(os.path.join(view_dir, stem + "_mask.png")) as mask:
                    view_masks.append(PILtoTorch(mask, resolution)[:3, ...])
        imgs.append(view_imgs)
        if with_mask:
            masks.append(view_masks)
    return imgs, masks

load_pose

可知单人情况下每帧对应的姿态文件共有 81 行,即 27 个关节,每个关节占 3 行,依次为位置、旋转、缩放。

def read_pose(file_path):
    """Read per-joint position, rotation and scale from a pose text file.

    Each joint occupies three consecutive comma-separated lines: position,
    rotation, scale. Blank lines (e.g. a trailing empty line) are ignored.

    Returns:
        ``(positions, rotations, scales)`` as tensors with one row per joint.

    Raises:
        ValueError: if the number of non-blank lines is not a multiple of 3.
    """
    with open(file_path, 'r') as file:
        rows = [line.strip() for line in file if line.strip()]
    if len(rows) % 3:
        raise ValueError(
            f"{file_path}: expected 3 lines per joint, got {len(rows)} non-blank lines")

    def parse(row):
        return [float(value) for value in row.split(',')]

    positions = [parse(row) for row in rows[0::3]]
    rotations = [parse(row) for row in rows[1::3]]
    scales = [parse(row) for row in rows[2::3]]
    return torch.tensor(positions), torch.tensor(rotations), torch.tensor(scales)

如0000.txt中:

0.112642,0.9490079,-0.3583371 # position
2.23881,4.080895,358.956 # rotation
1,1,1 # scale
def load_pose(dataroot, num_players, opt, split, num_joints=27):
    """Load one ``ModelParam`` per frame in ``range(opt.start, opt.end, opt.skip)``.

    Args:
        dataroot: sequence root; poses are read from ``<dataroot>/pose/<tttt>.txt``.
        num_players: players stored in each pose file.
        opt: split config providing ``start``, ``end`` (exclusive) and ``skip``.
        split: split name; not used here, kept for interface compatibility.
        num_joints: joints per player in each pose file (27 for the Gala data).

    Returns:
        A flat list with one ``ModelParam`` per frame, in time order (strictly
        follow the data format defined in ModelParam). Joint 0 supplies
        ``global_orient`` and ``transl``; the remaining joints' rotations
        form ``body_pose``.
    """
    poses = []
    joint_shape = [num_players, num_joints, 3]
    for t in range(opt.start, opt.end, opt.skip):
        pose = ModelParam()
        file_path = os.path.join(dataroot, "pose", str(t).zfill(4) + ".txt")
        positions, rotations, scales = read_pose(file_path)  # pose of frame t
        # reshape also checks that the file holds num_players * num_joints joints
        positions = positions.reshape(joint_shape)
        rotations = rotations.reshape(joint_shape)
        scales = scales.reshape(joint_shape)

        pose.body_pose = rotations[:, 1:]
        pose.global_orient = rotations[:, 0]  # root (joint 0) orientation
        pose.transl = positions[:, 0]  # root (joint 0) position
        # Scales are parsed but not stored on the pose; the assignment below is disabled.
        # pose.scale = scales[0].unsqueeze(0)
        poses.append(pose)
    return poses

以上,完成了数据集关于相机、图像、姿态的处理和加载。
在这里插入图片描述

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值