Spacetime Gaussian Feature Splatting for Real-Time Dynamic View Synthesis
🚩 Goal: novel view synthesis for dynamic scenes, targeting high-resolution photorealism, real-time rendering, and a compact model size.
💡 Proposes Spacetime Gaussian Feature Splatting, built from three components:
- A new Gaussian representation: Spacetime Gaussians (STGs), which augment 3D Gaussians with temporal opacity and parametric motion/rotation, enabling them to capture static, dynamic, and transient content.
- Splatted feature rendering: spherical harmonics are replaced with neural features. Splatted features model view- and time-dependent appearance while keeping the model size small.
- Guided sampling of new Gaussians using training error and coarse depth: adds Gaussians in regions that are otherwise hard to converge with the existing pipeline.
🚀 Performance: at 8K resolution, the lite version renders at 60 FPS on an RTX 4090.
3D Gaussian Splatting
Given multi-view images with known camera poses, 3D Gaussian Splatting represents a static 3D scene by optimizing anisotropic 3D Gaussians through differentiable rasterization. The efficient rasterizer enables real-time rendering of high-fidelity views.
Each 3D Gaussian $i$ is associated with a position $\mu_i$, a covariance matrix $\Sigma_i$, an opacity $\sigma_i$, and spherical harmonics coefficients $h_i$.
At any spatial point $x$, the opacity contributed by 3D Gaussian $i$ is:
$\alpha_i = \sigma_i \exp\left(-\frac{1}{2}(x-\mu_i)^T \Sigma_i^{-1}(x-\mu_i)\right). \quad (1)$
$\Sigma_i$ is positive semi-definite and can be decomposed into a scaling matrix $S_i$ and a rotation matrix $R_i$:
$\Sigma_i = R_i S_i S_i^T R_i^T$
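To make Eq. (1) and this decomposition concrete, here is a minimal PyTorch sketch (illustrative names, not the paper's code) that assembles $\Sigma_i$ from a rotation and a scaling matrix and evaluates the opacity at a single point:

import torch

def gaussian_opacity(x, mu, R, S, sigma):
    # Sigma_i = R_i S_i S_i^T R_i^T, then Eq. (1) evaluated at a single point x
    cov = R @ S @ S.T @ R.T
    d = x - mu
    mahalanobis = d @ torch.linalg.inv(cov) @ d   # (x - mu)^T Sigma^{-1} (x - mu)
    return sigma * torch.exp(-0.5 * mahalanobis)

# example: an axis-aligned Gaussian with per-axis scales (0.1, 0.2, 0.3)
R = torch.eye(3)
S = torch.diag(torch.tensor([0.1, 0.2, 0.3]))
alpha = gaussian_opacity(torch.tensor([0.05, 0.0, 0.0]), torch.zeros(3), R, S, sigma=0.8)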
Rendering pipeline, 3D $\rightarrow$ 2D:
- The 3D Gaussians are first projected to 2D image space through an approximation of the perspective transformation. Each projected 3D Gaussian is approximated by a 2D Gaussian with center $\mu_i^{2D}$ and covariance $\Sigma_i^{2D}$, where $W$ and $K$ are the viewing-transformation and projection matrices:
$\mu_i^{2D} = \left(K\left((W\mu_i)/(W\mu_i)_z\right)\right)_{1:2}$
$\Sigma_i^{2D} = \left(J W \Sigma_i W^T J^T\right)_{1:2,\,1:2}$
$J$ is the Jacobian of the projection matrix $K$.
- The Gaussians are sorted by depth, and each pixel color is obtained by volume rendering (see the sketch after this list):
$I = \sum_{i\in N} c_i\, \alpha_i^{2D} \prod_{j=1}^{i-1}\left(1-\alpha_j^{2D}\right)$
$\alpha_i^{2D}$ is the 2D version of Eq. (1), with $\mu_i, \Sigma_i, x$ replaced by their pixel-space counterparts; $c_i$ is the RGB color obtained by evaluating the SH with coefficients $h_i$ along the view direction.
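As a rough illustration of the two steps above, the sketch below projects Gaussian centers with $W$ and $K$ and alpha-composites front to back for a single pixel. It is a simplified, assumption-laden example (4x4 matrices, no 2D covariance footprint, no tile-based rasterization), not the actual CUDA rasterizer.

import torch

def project_and_composite(mus, colors, alphas2d, W, K):
    # mus: N x 3 Gaussian centers; W, K: 4 x 4 viewing and projection matrices (assumed shapes)
    homog = torch.cat([mus, torch.ones(mus.shape[0], 1)], dim=1)   # N x 4 homogeneous positions
    cam = homog @ W.T                                              # camera-space positions
    depth = cam[:, 2]
    proj = (cam / cam[:, 2:3]) @ K.T                               # divide by (W mu)_z, then apply K
    mu2d = proj[:, :2]                                             # 2D centers mu^2D
    # sort by depth and composite front to back; alphas2d stands in for the
    # per-pixel opacity of the projected 2D Gaussians
    order = torch.argsort(depth)
    C, T = torch.zeros(3), 1.0
    for i in order:
        C = C + T * alphas2d[i] * colors[i]
        T = T * (1.0 - alphas2d[i])
    return C, mu2d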
Method
Point clouds serve as input.
Spacetime Gaussian
- Extends 3D Gaussians with time.
- Introduces a temporal radial basis function that encodes temporal opacity, effectively modeling content that appears and disappears in the scene.
- Uses time-conditioned parametric functions to model the motion and rotation of 3D Gaussians.
For a spacetime point $(x, t)$, the opacity of an STG is:
$\alpha_i(t)=\sigma_i(t)\exp\left(-\frac{1}{2}(x-\mu_i(t))^T \Sigma_i^{-1}(x-\mu_i(t))\right)$
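Pulling the pieces together, a hedged single-Gaussian sketch of $\alpha_i(x,t)$: the temporal opacity $\sigma_i(t)$ (defined just below) gates a spatial Gaussian centered at the time-dependent position $\mu_i(t)$. All names are illustrative; `mu_t` and `cov_t` would come from the time-conditioned polynomials described later.

import torch

def stg_opacity(x, t, mu_t, cov_t, sigma_s, s_tau, mu_tau):
    # temporal opacity from the temporal radial basis function
    sigma_t = sigma_s * torch.exp(-s_tau * (t - mu_tau) ** 2)
    # spatial Gaussian evaluated at x with the time-dependent center mu_i(t)
    d = x - mu_t
    return sigma_t * torch.exp(-0.5 * d @ torch.linalg.inv(cov_t) @ d)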
Temporal radial basis function $\rightarrow$ temporal opacity
$\sigma_i(t)=\sigma_i^{s}\exp\left(-s_i^{\tau}\,|t-\mu_i^{\tau}|^2\right)$
$\mu_i^{\tau}$ is the temporal center; $s_i^{\tau}$ is the temporal scaling factor; $\sigma_i^{s}$ is the time-independent spatial opacity.
The temporal radial basis function is passed into the renderer:
render_pkg = render(viewpoint_cam, gaussians, pipe, background,
                    override_color=None,
                    basicfunction=rbfbasefunction,  # temporal radial basis function
                    GRsetting=GRsetting, GRzer=GRzer)
Implementation details inside the renderer:
# pc: GaussianModel
pointtimes = torch.ones((pc.get_xyz.shape[0], 1),
                        dtype=pc.get_xyz.dtype,
                        requires_grad=False, device="cuda") + 0
# temporal center and temporal scale parameters
trbfcenter = pc.get_trbfcenter
trbfscale = pc.get_trbfscale
pointopacity = pc.get_opacity  # time-independent spatial opacity
trbfdistanceoffset = viewpoint_camera.timestamp * pointtimes - trbfcenter  # t - mu^tau
trbfdistance = trbfdistanceoffset / torch.exp(trbfscale)  # normalize by the temporal scale
trbfoutput = basicfunction(trbfdistance)
opacity = pointopacity * trbfoutput  # temporal opacity
pc.trbfoutput = trbfoutput

def trbfunction(x):
    return torch.exp(-1 * x.pow(2))
Time-conditioned function $\rightarrow$ motion & rotation
Polynomial motion trajectory
For each STG, a time-conditioned function models its motion.
$\mu_i(t)=\sum_{k=0}^{n_p}b_{i,k}(t-\mu_i^{\tau})^k$
$\mu_i(t)$ is the spatial position of the STG at time $t$; $\{b_{i,k}\}_{k=0}^{n_p}$ are the corresponding polynomial coefficients, which are learnable.
Combining the temporal radial basis function with the time-conditioned polynomial motion trajectory lets a long, complex motion be represented by several simple ones.
$n_p = 3$ in this paper.
Implementation in the renderer:
means3D = pc.get_xyz
tforpoly = trbfdistanceoffset.detach()  # t - mu^tau
means3D = (means3D
           + pc._motion[:, 0:3] * tforpoly
           + pc._motion[:, 3:6] * tforpoly * tforpoly
           + pc._motion[:, 6:9] * tforpoly * tforpoly * tforpoly)
Polynomial rotation
A real-valued quaternion parameterizes the rotation matrix $R_i$ in $\Sigma_i = R_i S_i S_i^T R_i^T$.
Similar to the motion trajectory, the quaternion is represented by a polynomial:
$q_i(t)=\sum_{k=0}^{n_q}c_{i,k}(t-\mu_i^{\tau})^k$
$q_i(t)$ is the rotation of an STG at time $t$; $\{c_{i,k}\}_{k=0}^{n_q}$ are polynomial coefficients.
$n_q = 1$ in this paper.
After converting $q_i(t)$ to a rotation matrix $R_i(t)$, the covariance $\Sigma_i(t)$ can be obtained via $\Sigma_i = R_i S_i S_i^T R_i^T$.
Implementation in the renderer:
rotations = pc.get_rotation(tforpoly)  # t - mu^tau

def get_rotation(self, delta_t):
    # first-order (n_q = 1) polynomial: base quaternion plus a learned angular term
    rotation = self._rotation + delta_t * self._omega
    self.delta_t = delta_t
    return self.rotation_activation(rotation)
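For completeness, a minimal sketch of the step $q_i(t) \rightarrow R_i(t) \rightarrow \Sigma_i(t)$: a standard quaternion-to-rotation-matrix conversion followed by $\Sigma_i = R_i S_i S_i^T R_i^T$. This is plain PyTorch with illustrative helper names; in the repository this conversion happens inside the rasterizer.

import torch

def quaternion_to_rotmat(q):
    # q = (w, x, y, z), assumed normalized
    w, x, y, z = q
    return torch.stack([
        torch.stack([1 - 2*(y*y + z*z), 2*(x*y - w*z),     2*(x*z + w*y)]),
        torch.stack([2*(x*y + w*z),     1 - 2*(x*x + z*z), 2*(y*z - w*x)]),
        torch.stack([2*(x*z - w*y),     2*(y*z + w*x),     1 - 2*(x*x + y*y)]),
    ])

def covariance_at_t(q_t, scales):
    # Sigma_i(t) = R_i(t) S_i S_i^T R_i(t)^T with time-independent per-axis scales
    R = quaternion_to_rotmat(q_t / torch.linalg.norm(q_t))
    S = torch.diag(scales)
    return R @ S @ S.T @ R.T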
Splatted Feature Rendering
To encode view- and time-dependent radiance both accurately and compactly, the authors change how per-Gaussian appearance features are stored.
The features $f_i(t) \in \mathbb{R}^9$:
$f_i(t)=\left[f_i^{base},\, f_i^{dir},\, (t-\mu_i^{\tau})\,f_i^{time}\right]$
$f_i^{base}\in \mathbb{R}^3$ is the base RGB color; $f_i^{dir}, f_i^{time} \in \mathbb{R}^3$ encode information related to view direction and time.
The features $f_i(t)$ replace the RGB color $c_i$ in $I=\sum_{i\in N} c_i\, \alpha_i^{2D} \prod_{j=1}^{i-1}\left(1-\alpha_j^{2D}\right)$.
After splatting to image space, the splatted features at each pixel are split into $F^{base}, F^{dir}, F^{time}$.
Implementation in the renderer:
colors_precomp = pc.get_features(tforpoly)  # precomputed per-Gaussian features f_i(t)
rendered_image, radii, depth = rasterizer(
    means3D=means3D,                # xyz
    means2D=means2D,                # screen-space points
    shs=shs,
    colors_precomp=colors_precomp,
    opacities=opacity,
    scales=scales,
    rotations=rotations,
    cov3D_precomp=cov3D_precomp)

def get_features(self, deltat):
    return torch.cat((self._features_dc, deltat * self._features_t), dim=1)
The final RGB color at each pixel is obtained after going through a 2-layer MLP $\Phi$:
$I=F^{base}+\Phi\left(F^{dir}, F^{time}, r\right)$
$r$ is the view direction at the pixel and is additionally concatenated with the features as input.
rgbdecoder = Sandwich(9, 3)
rendered_image = pc.rgbdecoder(rendered_image.unsqueeze(0),
                               viewpoint_camera.rays,       # r, the view direction at the pixel
                               viewpoint_camera.timestamp)
rendered_image = rendered_image.squeeze(0)

class Sandwich(nn.Module):
    def __init__(self, dim, outdim=3, bias=False):
        super(Sandwich, self).__init__()
        self.mlp1 = nn.Conv2d(12, 6, kernel_size=1, bias=bias)
        self.mlp2 = nn.Conv2d(6, 3, kernel_size=1, bias=bias)
        self.relu = nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input, rays, time=None):
        albedo, spec, timefeature = input.chunk(3, dim=1)        # F_base, F_dir, F_time
        specular = torch.cat([spec, timefeature, rays], dim=1)   # 3 + 3 + ray channels = 12
        specular = self.mlp1(specular)
        specular = self.relu(specular)
        specular = self.mlp2(specular)
        result = albedo + specular
        result = self.sigmoid(result)
        return result
Sampling
Addresses the issue that regions with sparse Gaussians and regions far from the cameras are hard to converge to high-quality rendering.
Implementation
Find regions with large errors $\rightarrow$ locate the center pixels of those regions $\rightarrow$ constrain the depth range $\rightarrow$ shoot rays through the center pixels $\rightarrow$ add new Gaussian points along them.
- Sampling is performed only after the training loss has stabilized, so that the sampled errors are meaningful;
- Regions with large errors are selected by aggregating the training errors over patches;
ssimcurrent = ssim(image.detach(), gt_image.detach()).item()
if ssimcurrent < 0.88:
    imageadjust = image / (torch.mean(image) + 0.01)
    gtadjust = gt_image / (torch.mean(gt_image) + 0.01)
    diff = torch.abs(imageadjust - gtadjust)
    diff = torch.sum(diff, dim=0)  # h, w
    # use the per-pixel diff at the opt.emsthr quantile as the threshold
    diff_sorted, _ = torch.sort(diff.reshape(-1))
    numpixels = diff.shape[0] * diff.shape[1]
    threshold = diff_sorted[int(numpixels * opt.emsthr)].item()  # opt.emsthr = 0.6
    # mark pixels with large errors
    outmask = diff > threshold
    kh, kw = 16, 16  # kernel size
    dh, dw = 16, 16  # vertical and horizontal strides
    # compute padding so the image splits evenly into patches
    idealh, idealw = int(image.shape[1] / dh + 1) * kw, int(image.shape[2] / dw + 1) * kw
    outmask = torch.nn.functional.pad(
        outmask,
        (0, idealw - outmask.shape[1], 0, idealh - outmask.shape[0]),
        mode='constant', value=0)
    # partition outmask into patches with a sliding window
    patches = outmask.unfold(0, kh, dh).unfold(1, kw, dw)
    # keep only the patches whose aggregated error is large enough
    dummypatch = torch.ones_like(patches)
    # count the marked pixels within each patch
    patchessum = patches.sum(dim=(2, 3))
    patchesmusk = patchessum > kh * kh * 0.85
    patchesmusk = patchesmusk.unsqueeze(2).unsqueeze(3).repeat(1, 1, kh, kh).float()
    patches = dummypatch * patchesmusk
- Within each qualifying patch, keep its central region and select the center pixels;
midpatch = torch.ones_like(patches)
# zero out the pixels whose row and column indices are both even
for i in range(0, kh, 2):
    for j in range(0, kw, 2):
        midpatch[:, :, i, j] = 0.0
# keep a sparsified set of center pixels within each error patch
centerpatches = patches * midpatch

unfold_shape = patches.size()
patches_orig = patches.view(unfold_shape)
centerpatches_orig = centerpatches.view(unfold_shape)
output_h = unfold_shape[0] * unfold_shape[2]
output_w = unfold_shape[1] * unfold_shape[3]
patches_orig = patches_orig.permute(0, 2, 1, 3).contiguous()
centerpatches_orig = centerpatches_orig.permute(0, 2, 1, 3).contiguous()
centermask = centerpatches_orig.view(output_h, output_w).float()  # H x W mask, 1 for error, 0 for no error
# crop back to the original image size
centermask = centermask[:image.shape[1], :image.shape[2]]
errormask = patches_orig.view(output_h, output_w).float()  # H x W mask, 1 for error, 0 for no error
errormask = errormask[:image.shape[1], :image.shape[2]]
# discard a border of one tenth of the image size on each side
H, W = centermask.shape
offsetH = int(H / 10)
offsetW = int(W / 10)
centermask[0:offsetH, :] = 0.0
centermask[:, 0:offsetW] = 0.0
centermask[-offsetH:, :] = 0.0
centermask[:, -offsetW:] = 0.0
# pixel indices with large errors
badindices = centermask.nonzero()
- To avoid sampling over an overly large volume, a depth range is enforced and new Gaussians are sampled within it. The coarse depth map rendered from the Gaussian centers defines the sampling depth range;
render_pkg = render(viewpoint_cam, gaussians, pipe, background, override_color=None,
                    basicfunction=rbfbasefunction, GRsetting=GRsetting, GRzer=GRzer)
depth = render_pkg["depth"]
diff_sorted, _ = torch.sort(depth.reshape(-1))
N = diff_sorted.shape[0]
mediandepth = int(0.7 * N)
mediandepth = diff_sorted[mediandepth]  # depth value at the 70th percentile
depth = torch.where(depth > mediandepth, depth, mediandepth)
# add new Gaussians within the depth range (mediandepth, depthmax]
totalNnewpoints = gaussians.addgaussians(badindices,
                                         viewpoint_cam,
                                         depth,
                                         gt_image,
                                         numperay=opt.farray, ratioend=opt.rayends,
                                         depthmax=depthdict[viewpoint_cam.image_name],
                                         shuffle=(opt.shuffleems != 0))
visibility_filter = torch.cat((visibility_filter, torch.zeros(totalNnewpoints).cuda(0)), dim=0)
radii = torch.cat((radii, torch.zeros(totalNnewpoints).cuda(0)), dim=0)
viewspace_point_tensor = torch.cat((viewspace_point_tensor, torch.zeros(totalNnewpoints, 3).cuda(0)), dim=0)
- New Gaussians are sampled along the rays through pixels with large training errors;
def addgaussians(self, baduvidx, viewpoint_cam, depthmap, gt_image, numperay=3, ratioend=2,
                 trbfcenter=0.5, depthmax=None, shuffle=False):
    def pix2ndc(v, S):
        return (v * 2.0 + 1.0) / S - 1.0

    # ground-truth colors at the high-error pixels initialize the base feature
    rgbs = gt_image[:, baduvidx[:, 0], baduvidx[:, 1]]
    rgbs = rgbs.permute(1, 0)
    # should we add the feature dc with non-zero values? direction feature
    featuredc = torch.cat((rgbs, torch.zeros_like(rgbs)), dim=1)
    depths = depthmap[:, baduvidx[:, 0], baduvidx[:, 1]]
    depths = depths.permute(1, 0)
    # use the max local depth of the scene as the reference depth
    depths = torch.ones_like(depths) * depthmax
    maxx, minx = self.maxx, self.minx
    # pixel coordinates of the high-error points
    u = baduvidx[:, 0]  # height, y
    v = baduvidx[:, 1]  # width, x
    # accumulators for the newly created Gaussians
    new_xyz, new_features_dc, new_featuret = [], [], []
    new_trbf_center, new_trbf_scale, new_motion, new_omega = [], [], [], []
    # numperay depth ratios between raystart and ratioend
    ratiaolist = torch.linspace(self.raystart, ratioend, numperay)
    for zscale in ratiaolist:
        ndcu, ndcv = pix2ndc(u, viewpoint_cam.image_height), pix2ndc(v, viewpoint_cam.image_width)
        if shuffle == True:
            randomdepth = torch.rand_like(depths) - 0.5  # -0.5 to 0.5
            # jitter the depth around the reference value
            targetPz = (depths + depths / 10 * randomdepth) * zscale
        else:
            targetPz = depths * zscale  # depth in the local camera frame
        ndcu = ndcu.unsqueeze(1)
        ndcv = ndcv.unsqueeze(1)
        ndccamera = torch.cat((ndcv,
                               ndcu,
                               torch.ones_like(ndcu) * (1.0),
                               torch.ones_like(ndcu)), dim=1)  # N, 4
        # unproject to camera space; projectinverse and camera2wold are the inverse
        # projection and camera-to-world matrices of viewpoint_cam (defined elsewhere)
        localpointuv = ndccamera @ projectinverse.T
        # ray direction in camera space
        diretioninlocal = localpointuv / localpointuv[:, 3:]
        # ratio between the target depth and the ray's z coordinate
        rate = targetPz / diretioninlocal[:, 2:3]
        # point at the target depth
        localpoint = diretioninlocal * rate
        localpoint[:, -1] = 1
        # transform to world coordinates
        worldpointH = localpoint @ camera2wold.T
        worldpoint = worldpointH / worldpointH[:, 3:]
        # target points on the ray in world coordinates
        xyz = worldpoint[:, :3]
        # points whose x coordinate lies in (minx, maxx)
        xmask = torch.logical_and(xyz[:, 0] > minx, xyz[:, 0] < maxx)
        # the mask and its complement together keep every point
        selectedmask = torch.logical_or(xmask, torch.logical_not(xmask))
        # store the target points (points on the ray)
        new_xyz.append(xyz[selectedmask])
        new_features_dc.append(featuredc.cuda(0)[selectedmask])
        selectnumpoints = torch.sum(selectedmask).item()
        new_trbf_center.append(torch.rand((selectnumpoints, 1)).cuda())
        assert self.trbfslinit < 1
        new_trbf_scale.append(self.trbfslinit * torch.ones((selectnumpoints, 1), device="cuda"))
        new_motion.append(torch.zeros((selectnumpoints, 9), device="cuda"))
        new_omega.append(torch.zeros((selectnumpoints, 4), device="cuda"))
        new_featuret.append(torch.zeros((selectnumpoints, 3), device="cuda"))

    new_xyz = torch.cat(new_xyz, dim=0)
    new_rotation = torch.zeros((new_xyz.shape[0], 4), device="cuda")
    new_rotation[:, 1] = 0
    new_features_dc = torch.cat(new_features_dc, dim=0)
    new_opacity = inverse_sigmoid(0.1 * torch.ones_like(new_xyz[:, 0:1]))
    new_trbf_center = torch.cat(new_trbf_center, dim=0)
    new_trbf_scale = torch.cat(new_trbf_scale, dim=0)
    new_motion = torch.cat(new_motion, dim=0)
    new_omega = torch.cat(new_omega, dim=0)
    new_featuret = torch.cat(new_featuret, dim=0)
    # initialize the new scales from the distance to the nearest existing points
    tmpxyz = torch.cat((new_xyz, self._xyz), dim=0)
    dist2 = torch.clamp_min(distCUDA2(tmpxyz), 0.0000001)
    dist2 = dist2[:new_xyz.shape[0]]
    scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
    scales = torch.clamp(scales, -10, 1.0)
    new_scaling = scales
    # append the new Gaussians to the model (densification)
    self.densification_postfix(new_xyz, new_features_dc, new_opacity, new_scaling, new_rotation,
                               new_trbf_center, new_trbf_scale, new_motion, new_omega, new_featuret)
Focus on the marked pixels
Correctly rendered pixels are masked out and only the high-error ones are kept, so that the next iterations focus on optimizing them:
gt_image = gt_image * errormask
image = render_pkg["render"] * errormask
torchvision.utils.save_image(gt_image,
os.path.join(pathdir, "maskedudgt" + str(iteration) + ".png"))
torchvision.utils.save_image(image,
os.path.join(pathdir, "maskedrender" + str(iteration) + ".png"))
Advantages
The feature-based representation requires fewer parameters than spherical-harmonics encoding and renders faster.
$\Phi$ can be dropped to further speed up rendering (the lite version).
Loss
The loss is computed between the rendered image and the ground-truth image, combining $L_1$ and D-SSIM terms.
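A hedged sketch of how the two terms are typically combined. The $\lambda = 0.2$ D-SSIM weighting follows the original 3D Gaussian Splatting convention and is an assumption here, not something stated in this note; `ssim` is the same helper used in the sampling code above.

import torch

lambda_dssim = 0.2  # assumed weighting, following 3D Gaussian Splatting
l1 = torch.abs(image - gt_image).mean()
loss = (1.0 - lambda_dssim) * l1 + lambda_dssim * (1.0 - ssim(image, gt_image))
loss.backward()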