总体
https://github.com/nerfstudio-project/gsplat
simple_trainer_mcmc.py
2个关键点:
- 高斯状态转移(每100iter调用)
- 高斯随机过程(每1iter调用)
relocate_gs
- 对 alive gs 进行采样,被采样的 alive gs 将作为 dead gs 的转移目标。
- 对被采样的 alive gs 进行状态更新,opacities和scales属性会重新计算。
- 对 dead gs 进行状态转移。
add_new_gs
- 对 all gs 进行采样
- 被采样的 gs 进行状态更新,opacities和scales属性会重新计算
- 再把被采样的 gs 作为 copy,添加到所有 gs 中。
add_noise_to_gs
- 根据 学习率 和 opacities 控制噪声的大小
- 根据 quats 和 scales 控制噪声的分布
- 得到 delt_xyz 噪声
- 添加到 gs 的 xyz 属性上
代码AI解读
relocate_gs
(add_new_gs 类似)
@torch.no_grad()
def relocate_gs(self, min_opacity: float = 0.005) -> int:
dead_mask = torch.sigmoid(self.splats["opacities"]) <= min_opacity
dead_indices = dead_mask.nonzero(as_tuple=True)[0]
alive_indices = (~dead_mask).nonzero(as_tuple=True)[0]
num_gs = len(dead_indices)
if num_gs <= 0:
return num_gs
# Sample for new GSs
eps = torch.finfo(torch.float32).eps
probs = torch.sigmoid(self.splats["opacities"])[alive_indices]
probs = probs / (probs.sum() + eps)
sampled_idxs = torch.multinomial(probs, num_gs, replacement=True) # 进行多项式采样,num_gs 是要重新定位的粒子数量,replacement=True 表示允许重复采样。
sampled_idxs = alive_indices[sampled_idxs]
new_opacities, new_scales = compute_relocation(
opacities=torch.sigmoid(self.splats["opacities"])[sampled_idxs],
scales=torch.exp(self.splats["scales"])[sampled_idxs],
ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1, # torch.bincount: 这个函数计算输入张量中每个整数值的出现次数。对于 sampled_idxs,torch.bincount 的输出将是一个包含每个索引出现次数的张量。例如,对于 sampled_idxs = [2, 1, 2, 3, 1],torch.bincount(sampled_idxs) 的输出将是 [0, 2, 2, 1]。这里,0 表示索引 0 没有出现,2 表示索引 1 出现了两次,2 表示索引 2 出现了两次,1 表示索引 3 出现了一次。
)
new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity)
self.splats["opacities"][sampled_idxs] = torch.logit(new_opacities)
self.splats["scales"][sampled_idxs] = torch.log(new_scales)
# Update splats and optimizers
for k in self.splats.keys():
self.splats[k][dead_indices] = self.splats[k][sampled_idxs]
for optimizer in self.optimizers:
for i, param_group in enumerate(optimizer.param_groups):
p = param_group["params"][0]
name = param_group["name"]
p_state = optimizer.state[p]
del optimizer.state[p]
for key in p_state.keys():
if key != "step":
p_state[key][sampled_idxs] = 0
p_new = torch.nn.Parameter(self.splats[name])
optimizer.param_groups[i]["params"] = [p_new]
optimizer.state[p_new] = p_state
self.splats[name] = p_new
torch.cuda.empty_cache()
return num_gs
compute_relocation
// Equation (9) in "3D Gaussian Splatting as Markov Chain Monte Carlo"
__global__ void compute_relocation_kernel(int N, float *opacities, float *scales,
int *ratios, float *binoms, int n_max,
float *new_opacities, float *new_scales) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= N)
return;
int n_idx = ratios[idx];
float denom_sum = 0.0f;
// compute new opacity
new_opacities[idx] = 1.0f - powf(1.0f - opacities[idx], 1.0f / n_idx);
// compute new scale
for (int i = 1; i <= n_idx; ++i) {
for (int k = 0; k <= (i - 1); ++k) {
float bin_coeff = binoms[(i - 1) * n_max + k];
float term = (pow(-1.0f, k) / sqrt(static_cast<float>(k + 1))) *
pow(new_opacities[idx], k + 1);
denom_sum += (bin_coeff * term);
}
}
float coeff = (opacities[idx] / denom_sum);
for (int i = 0; i < 3; ++i)
new_scales[idx * 3 + i] = coeff * scales[idx * 3 + i];
}
计算新的透明度(Opacity):
使用公式 new_opacities[idx]=1.0−(1.0−opacities[idx])1.0/n_idx\text{new\_opacities}[idx] = 1.0 - (1.0 - \text{opacities}[idx])^{1.0 / n\_idx}new_opacities[idx]=1.0−(1.0−opacities[idx])1.0/n_idx 来计算新的透明度。这个公式是基于论文中的公式 (9) 推导出来的。
计算新的尺度(Scale):
通过一个嵌套的循环来计算新的尺度。这个过程涉及到二项式系数(
binoms
)和一些数学运算,包括幂运算和平方根运算。具体来说,内核函数计算了一个系数
coeff
,然后用这个系数来调整原始的尺度值,得到新的尺度值。
add_noise_to_gs
@torch.no_grad()
def add_noise_to_gs(self, last_lr):
opacities = torch.sigmoid(self.splats["opacities"])
scales = torch.exp(self.splats["scales"])
actual_covariance, _ = quat_scale_to_covar_preci(
self.splats["quats"],
scales,
compute_covar=True,
compute_preci=False,
triu=False,
)
def op_sigmoid(x, k=100, x0=0.995):
return 1 / (1 + torch.exp(-k * (x - x0)))
noise = (
torch.randn_like(self.splats["means3d"])
* (op_sigmoid(1 - opacities)).unsqueeze(-1)
* cfg.noise_lr
* last_lr
)
noise = torch.bmm(actual_covariance, noise.unsqueeze(-1)).squeeze(-1)
self.splats["means3d"].add_(noise) # 只改变xyz