接上节,【pytorch】Animatable 3D Gaussian+源码解读(一),完成高斯球定义后,开始关注NeRFModel。
class NeRFModel(pl.LightningModule):
    """Lightning module that wraps the deformer model and its evaluation metrics.

    Only the constructor is part of this excerpt.
    """

    def __init__(self, opt):
        """Build the model, copy the training settings and prepare output folders.

        Args:
            opt: Configuration object. The constructor reads ``opt.deformer``
                (instantiated through Hydra), ``opt.training_args``,
                ``opt.max_sh_degree`` and ``opt.lambda_dssim``.
        """
        super().__init__()
        # Stores every argument passed to __init__ as a hyperparameter; they
        # are later reachable as self.hparams.<arg>, and the hyperparameter
        # table is also written to file. Must be called from __init__ itself.
        self.save_hyperparameters()
        # Instantiate the deformer (the gala_model analysed in the previous part).
        self.model = hydra.utils.instantiate(opt.deformer)
        self.training_args = opt.training_args
        self.sh_degree = opt.max_sh_degree
        self.lambda_dssim = opt.lambda_dssim
        self.evaluator = Evaluator()
        # exist_ok=True replaces the former "check, then create" sequence,
        # which could fail when several processes build the model at once.
        for out_dir in ("val", "test"):
            os.makedirs(out_dir, exist_ok=True)
首先定义了一些训练超参数;
然后是衡量指标:LPIPS、PSNR、SSIM,
它们都封装在Evaluator类内:
class Evaluator(nn.Module):
    """Bundle of the three image metrics used for evaluation: PSNR, SSIM, LPIPS.

    Adapted from https://github.com/JanaldoChen/Anim-NeRF/blob/main/models/evaluator.py
    """

    def __init__(self):
        """Create the metric objects (LPIPS with the "alex" network, data range 1)."""
        super().__init__()
        self.lpips = LearnedPerceptualImagePatchSimilarity(net_type="alex")
        self.psnr = PeakSignalNoiseRatio(data_range=1)
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1)

    @custom_fwd(cast_inputs=torch.float32)
    def forward(self, rgb, rgb_gt):
        """Score a rendered image against its ground truth.

        Args:
            rgb: Predicted image tensor.
            rgb_gt: Ground-truth image tensor.

        Returns:
            dict with the keys ``"psnr"``, ``"ssim"`` and ``"lpips"``, each
            holding the value returned by the corresponding metric.
        """
        scores = {}
        # Same evaluation order and key order as before: psnr, ssim, lpips.
        for metric_name in ("psnr", "ssim", "lpips"):
            metric = getattr(self, metric_name)
            scores[metric_name] = metric(rgb, rgb_gt)
        return scores
实例化:
model = NeRFModel(opt)
galaBasketball.py
datamodule = hydra.utils.instantiate(opt.dataset)
在继续之前,先回头看一下opt对应的配置文件:
train.py中默认使用gala:
@hydra.main(config_path="./confs", config_name="gala", version_base="1.1")
找到confs/gala.yaml,指定dataset:gala/idle:
defaults:
- _self_
- dataset: gala/idle
- deformer: gala_model
找到gala/idle.yaml,由此收获接下来所需要的一系列超参数:
_target_: animatableGaussian.dataset.galaBasketball.GalaBasketballDataModule
name: idle
num_workers: ${num_workers}
num_players: 1
# Everything below is passed to the data module as its `opt` argument; the
# per-split sections are looked up by split name.
opt:
  dataroot: ../../data/Gala/idle
  max_freq: ${max_freq}
  with_mask: False
  train:
    # provides parameters needed to split the train set
    camera_ids: 0,1,2,3,4,5
    start: 0
    end: 100
    skip: 2
  val:
    # provides parameters needed to split the val set
    camera_ids: 0
    start: 0
    end: 0
    skip: 2
  test:
    # provides parameters needed to split the test set
    camera_ids: 6
    start: 0
    end: 100
    skip: 2
那么,
datamodule = hydra.utils.instantiate(opt.dataset)
即实例化
_target_: animatableGaussian.dataset.galaBasketball.GalaBasketballDataModule
class GalaBasketballDataModule(pl.LightningDataModule):
    """Lightning data module that builds the Gala basketball datasets and loaders."""

    def __init__(self, num_workers, num_players, opt, train=True, **kwargs):
        """Load the datasets for the requested phase.

        Args:
            num_workers: Worker count handed to every DataLoader.
            num_players: Number of players, forwarded to the dataset.
            opt: Dataset options. Reads ``opt.dataroot``, ``opt.with_mask``,
                ``opt.max_freq`` and one sub-section per split via
                ``opt.get(split)``.
            train: When true the "train" and "val" splits are loaded,
                otherwise only "test".
            **kwargs: Extra configuration keys; accepted and ignored.
        """
        super().__init__()
        self.num_workers = num_workers
        # Decide which splits this phase needs.
        splits = ["train", "val"] if train else ["test"]
        for split in splits:
            print(f"loading {split}set...")
            dataset = GalaBasketballDataset(
                opt.dataroot, num_players, opt.with_mask, opt.max_freq,
                split, opt.get(split))
            # Exposed as self.trainset / self.valset / self.testset.
            setattr(self, f"{split}set", dataset)

    def _build_loader(self, dataset, shuffle):
        """Create a batch-size-1 DataLoader with the settings shared by all splits."""
        return DataLoader(dataset,
                          shuffle=shuffle,
                          pin_memory=True,
                          batch_size=1,
                          # DataLoader rejects persistent workers when there
                          # are no workers, so only request them if any exist.
                          persistent_workers=self.num_workers > 0,
                          num_workers=self.num_workers,
                          collate_fn=my_collate_fn)

    def train_dataloader(self):
        """Return the shuffled training loader, or defer to the base class."""
        if hasattr(self, "trainset"):
            return self._build_loader(self.trainset, shuffle=True)
        return super().train_dataloader()

    def val_dataloader(self):
        """Return the validation loader (no shuffling), or defer to the base class."""
        if hasattr(self, "valset"):
            return self._build_loader(self.valset, shuffle=False)
        return super().val_dataloader()

    def test_dataloader(self):
        """Return the test loader (no shuffling), or defer to the base class."""
        if hasattr(self, "testset"):
            return self._build_loader(self.testset, shuffle=False)
        return super().test_dataloader()
三个dataloader的写法基本相同(只有训练集开启shuffle),没有特别之处,所以还是把注意力放在dataset的处理上。
dataset
首先就是图像、相机、pose的一些数据处理:
class GalaBasketballDataset(torch.utils.data.Dataset):
    """Dataset holding the cameras, images (and masks) and poses of one split.

    Only the constructor is part of this excerpt.
    """

    def __init__(self, dataroot, num_players, with_mask, max_freq, split, opt):
        """Read every camera, image and pose the split refers to.

        Args:
            dataroot: Dataset root directory.
            num_players: Number of players, forwarded to ``load_pose``.
            with_mask: Whether mask images are read alongside the frames.
            max_freq: Stored on the instance as ``self.max_freq``.
            split: Split name, stored and forwarded to ``load_pose``.
            opt: Split options providing ``start``, ``end``, ``skip`` and
                ``camera_ids``.
        """
        self.split = split
        self.max_freq = max_freq
        self.opt = opt
        # Count frames exactly the way the image and pose loaders enumerate
        # them, so the value is an int even when (end - start) is not a
        # multiple of skip.
        self.num_frames = len(range(opt.start, opt.end, opt.skip))
        # camera_ids is either a comma-separated string such as "0,1,2" or a
        # single value, which is then treated as one camera id.
        if isinstance(opt.camera_ids, str):
            self.camera_ids = [
                int(cam) for cam in opt.camera_ids.strip().split(',')]
        else:
            self.camera_ids = [opt.camera_ids]
        self.num_cameras = len(self.camera_ids)
        self.camera_params = getCamPara(dataroot, self.camera_ids, opt)
        # The first camera's resolution is used for every image that is read.
        self.image_width = self.camera_params[0].image_width
        self.image_height = self.camera_params[0].image_height
        self.imgs, self.masks = read_imgs(
            dataroot, with_mask, self.camera_ids, opt,
            (self.image_width, self.image_height))
        self.poses = load_pose(dataroot, num_players, opt, split)
代码清晰易懂,展开看一下相机、图像、姿态的处理细节:
getCamPara
通过getCamPara函数可以对应上(一)中描述相机参数文本文件的数据,
比如0.txt:
25,512,512 # property:fov(视场角) width height
6,4,0 # position
0.1530459,-0.6903456,0.1530459,0.6903456 # rotation
1,1,1 # scaling
根据opt,相机0~5共6个视角用于训练,相机6这1个视角用于测试(验证集使用相机0)。
def getCamPara(dataroot, camera_ids, opt):
    """Load the parameters and the background image of every requested camera.

    For each camera id this reads ``<dataroot>/camera/<id>.txt``, whose first
    four lines are comma-separated numbers ("fov,width,height", position,
    rotation, scaling), and ``<dataroot>/bg/<id>.png``.

    Args:
        dataroot: Dataset root directory.
        camera_ids: Sequence of camera ids to load.
        opt: Split options; not used here, kept so the signature is unchanged.

    Returns:
        list of ``Camera`` objects, one per entry of ``camera_ids``, in order.

    Raises:
        IndexError: If a camera file has fewer than four lines.
    """
    # load all camera and corresponding bg image
    camera_params = []
    for cam_id in camera_ids:  # one pass per camera view
        file_path = os.path.join(dataroot, "camera", str(cam_id) + ".txt")
        with open(file_path, 'r') as file:
            lines = file.readlines()

        intrinsics, position, rotation, scaling = (
            list(map(float, lines[row].strip().split(',')))
            for row in range(4))

        fov = intrinsics[0]
        width = int(intrinsics[1])
        height = int(intrinsics[2])
        wh_ratio = intrinsics[1] / intrinsics[2]
        projmatrix = get_proj_mat(0.01, 1000, fov, wh_ratio)  # projection matrix
        viewmatrix = matrix_TRS(position, rotation, scaling).inverse()  # view matrix
        # works well with unity coordinate system.
        # NOTE(review): the index list names row 0 twice, so this appears to
        # negate row 0 only. If a second, different row was intended, confirm
        # against the upstream repository before changing it.
        viewmatrix[[0, 0], :] *= -1

        # Put the camera parameters into the expected format:
        # strictly follow the data format defined in Camera
        camera_param = Camera()
        camera_param.image_height = height
        camera_param.image_width = width
        camera_param.tanfovx = 1. / projmatrix[0, 0]
        camera_param.tanfovy = 1. / projmatrix[1, 1]

        # Read the background image (the one questioned in part one of this
        # write-up).
        bg_path = os.path.join(dataroot, "bg", str(cam_id) + ".png")
        # Follow the following code exactly to read and process images
        bg = Image.open(bg_path)
        bg = PILtoTorch(bg, [width, height])

        # The background is outside the reconstructed region but has to be
        # part of the rendering, so it is used directly as camera_param.bg.
        # replace with the real bg
        camera_param.bg = bg
        camera_param.scale_modifier = 1.0
        camera_param.viewmatrix = viewmatrix.T
        camera_param.projmatrix = (projmatrix @ viewmatrix).T
        camera_param.campos = torch.inverse(camera_param.viewmatrix)[3, :3]
        camera_params.append(camera_param)
    # The caller indexes the result, so the collected list must be returned.
    return camera_params
read_imgs
从opt设置中可知:300帧中,训练时从第0帧到第100帧(不含第100帧)每隔2帧取一帧(skip=2),即第0、2、…、98帧,每个视角共50帧
read_imgs将6个视角下的图片都放在imgs中,mask同理
def read_imgs(dataroot, with_mask, camera_ids, opt, resolution):
    """Read every sampled frame (and optionally its mask) of every camera.

    Frames are expected at ``<dataroot>/<camera id>/<frame:04>.png`` and masks
    at ``<dataroot>/<camera id>/<frame:04>_mask.png``; frame numbers come from
    ``range(opt.start, opt.end, opt.skip)``.

    Args:
        dataroot: Dataset root directory.
        with_mask: Whether to read the mask next to each frame.
        camera_ids: Sequence of camera ids.
        opt: Split options providing ``start``, ``end`` and ``skip``.
        resolution: Target resolution handed to ``PILtoTorch``.

    Returns:
        ``(imgs, masks)``. ``imgs[camera][frame]`` holds the first three
        channels of the converted image. ``masks`` has the same layout, or is
        an empty list when ``with_mask`` is false. Both are nested Python
        lists, not stacked tensors.
    """
    imgs = []
    masks = []
    for view_id in camera_ids:
        print(f"loading images from camera {view_id}")
        view_dir = os.path.join(dataroot, str(view_id))
        frames = []
        frame_masks = []
        for t in tqdm(range(opt.start, opt.end, opt.skip)):
            stem = str(t).zfill(4)
            # Follow the following code exactly to read and process images
            frame = Image.open(os.path.join(view_dir, stem + ".png"))
            frame = PILtoTorch(frame, resolution)
            frames.append(frame[:3, ...])
            if not with_mask:
                continue
            mask = Image.open(os.path.join(view_dir, stem + "_mask.png"))
            mask = PILtoTorch(mask, resolution)
            frame_masks.append(mask[:3, ...])
        imgs.append(frames)
        if with_mask:
            masks.append(frame_masks)
    return imgs, masks
load_pose
可知每帧对应的姿态文件中共有81行,即27个关节,每个关节占3行,依次为位置(position)、旋转(rotation)、缩放(scale)。
def read_pose(file_path):
    """Parse one pose file into per-joint position, rotation and scale tensors.

    The file stores three comma-separated lines per joint, in the order
    position, rotation, scale. Blank lines (for example a trailing empty
    line) are ignored instead of aborting the parse.

    Args:
        file_path: Path of the pose text file.

    Returns:
        Tuple ``(positions, rotations, scales)`` of tensors with one row per
        joint.

    Raises:
        IndexError: If the last joint has fewer than three lines.
        ValueError: If a field cannot be converted to float.
    """
    with open(file_path, 'r') as file:
        # Keep only lines that carry data; each becomes a list of floats.
        rows = [
            [float(value) for value in line.strip().split(',')]
            for line in file
            if line.strip()
        ]

    positions = []
    rotations = []
    scales = []
    for i in range(0, len(rows), 3):
        positions.append(rows[i])      # bone joint's position
        rotations.append(rows[i + 1])  # bone joint's rotation
        scales.append(rows[i + 2])     # bone joint's scale
    return torch.tensor(positions), torch.tensor(rotations), torch.tensor(scales)
如0000.txt中:
0.112642,0.9490079,-0.3583371 # position
2.23881,4.080895,358.956 # rotation
1,1,1 # scale
def load_pose(dataroot, num_players, opt, split, num_joints=27):
    """Load the pose of every sampled frame.

    Pose files are read from ``<dataroot>/pose/<frame:04>.txt`` for each frame
    in ``range(opt.start, opt.end, opt.skip)``. The files are indexed by frame
    only, so the result has no camera dimension.

    Args:
        dataroot: Dataset root directory.
        num_players: Number of players stored in each pose file.
        opt: Split options providing ``start``, ``end`` and ``skip``.
        split: Split name; not used here, kept so the signature is unchanged.
        num_joints: Joints per player. Defaults to 27, the value that was
            previously hard-coded.

    Returns:
        list of ``ModelParam``, one per sampled frame in time order. Each has
        ``body_pose`` (rotations of all joints but the first),
        ``global_orient`` (rotation of the first joint) and ``transl``
        (position of the first joint).
    """
    # strictly follow the data format defined in ModelParam
    poses = []
    joint_shape = [num_players, num_joints, 3]
    for t in range(opt.start, opt.end, opt.skip):
        file_path = os.path.join(dataroot, "pose", str(t).zfill(4) + ".txt")
        positions, rotations, scales = read_pose(file_path)  # pose of frame t
        positions = positions.reshape(joint_shape)
        rotations = rotations.reshape(joint_shape)
        # Scales are not attached to the pose, but reshaping them still
        # rejects files whose scale block has the wrong size.
        scales = scales.reshape(joint_shape)

        pose = ModelParam()
        pose.body_pose = rotations[:, 1:]
        pose.global_orient = rotations[:, 0]  # root orientation
        pose.transl = positions[:, 0]  # root position
        poses.append(pose)
    return poses
以上,完成了数据集关于相机、图像、姿态的处理和加载。