3DGS中的优化方案
由于初始化点云可能导致生成高斯在空间中密度过大或过小,3dgs给出一些手段来在学习过程中自适应地调控密度,具体方法有点密集化和点剪枝。
点密集化
重建过度的区域拆分大高斯
官方代码
def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
n_init_points = self.get_xyz.shape[0]
# Extract points that satisfy the gradient condition
padded_grad = torch.zeros((n_init_points), device="cuda")
padded_grad[:grads.shape[0]] = grads.squeeze()
selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
selected_pts_mask = torch.logical_and(selected_pts_mask,
torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
stds = self.get_scaling[selected_pts_mask].repeat(N,1)
means =torch.zeros((stds.size(0), 3),device="cuda")
samples = torch.normal(mean=means, std=stds)
rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
# 新生成点云信息(由大高斯分割并按一定系数缩放得到)
new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
# 裁剪过程参照其不透明度值进行,裁剪掉不必要的点提高效率
self.prune_points(prune_filter)
实现过程
- 获取初始点云的数量 n_init_points。
- 创建一个与梯度相同大小的零张量 padded_grad,并将梯度值填充到其中。
- 根据梯度的阈值条件和点云的缩放因子,生成一个选择点的掩码 selected_pts_mask。
- 将满足梯度和缩放条件的点复制 N次,并计算新的坐标、缩放、旋转和特征。
- 将新生成的点云和特征附加到原始点云中。
- 创建一个用于剪枝的过滤器prune_filter,其中包括原始点云和新生成的点云的掩码。
- 根据剪枝过滤器删除不需要的点。
重建不足的区域克隆小高斯
官方代码
def densify_and_clone(self, grads, grad_threshold, scene_extent):
# Extract points that satisfy the gradient condition
selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
selected_pts_mask = torch.logical_and(selected_pts_mask,
torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
new_xyz = self._xyz[selected_pts_mask]
new_features_dc = self._features_dc[selected_pts_mask]
new_features_rest = self._features_rest[selected_pts_mask]
new_opacities = self._opacity[selected_pts_mask]
new_scaling = self._scaling[selected_pts_mask]
new_rotation = self._rotation[selected_pts_mask]
self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)
实现过程
- 根据梯度的阈值条件和点云的缩放因子,生成一个选择点的掩码 selected_pts_mask。
- 基于选择的点生成新的点云坐标、特征、不透明度、缩放和旋转。
- 将新生成的点云和特征附加到原始点云中。
- 调用 densification_postfix 函数对点云进行后处理。
点剪枝
将不透明度小于一定阈值的点减去,将过大的也减去,类似于正则化过程。并且在迭代一定次数后,高斯会被设置为几乎透明。这样就能有控制地增加必要的高斯密度,同时剔除多余的高斯。
官方代码
def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
grads = self.xyz_gradient_accum / self.denom
grads[grads.isnan()] = 0.0
# 点密集化过程
self.densify_and_clone(grads, max_grad, extent)
self.densify_and_split(grads, max_grad, extent)
prune_mask = (self.get_opacity < min_opacity).squeeze()
if max_screen_size:
big_points_vs = self.max_radii2D > max_screen_size
big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
self.prune_points(prune_mask)
torch.cuda.empty_cache()
def prune_points(self, mask):
valid_points_mask = ~mask
optimizable_tensors = self._prune_optimizer(valid_points_mask)
self._xyz = optimizable_tensors["xyz"]
self._features_dc = optimizable_tensors["f_dc"]
self._features_rest = optimizable_tensors["f_rest"]
self._opacity = optimizable_tensors["opacity"]
self._scaling = optimizable_tensors["scaling"]
self._rotation = optimizable_tensors["rotation"]
self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
self.denom = self.denom[valid_points_mask]
self.max_radii2D = self.max_radii2D[valid_points_mask]
实现过程
-
上述代码根据最小不透明度和最大屏幕尺寸等条件生成剪枝掩码 prune_mask,并调用 prune_points 函数进行点云的剪枝操作。
-
prune_points 函数根据给定的掩码 mask 对点云进行剪枝操作,将不需要的点从点云中删除,并更新相关的张量数据。