Spacetime Gaussian Feature Splatting for Real-Time Dynamic View Synthesis (Notes)

🚩 Goal: view synthesis for dynamic scenes, targeting photorealistic quality at high resolution, real-time rendering, and a compact model size.

💡 Proposes Spacetime Gaussian Feature Splatting, built from three components:

  • A new Gaussian representation: Spacetime Gaussians (STGs), i.e. 3D Gaussians enhanced with temporal opacity and parametric motion/rotation, which lets them capture static, dynamic, and transient content.
  • Splatted feature rendering: spherical harmonics are replaced with neural features. Splatted features model view- and time-dependent appearance while keeping the model small.
  • Guided sampling of new Gaussians, driven by training error and coarse depth: adds Gaussians in regions that are hard to converge with the existing pipeline.

🚀 Performance: at 8K resolution, the lite version renders at 60 FPS on an RTX 4090.

3D Gaussian Splatting

Given multi-view images with known camera poses, 3D Gaussian Splatting represents a static 3D scene by optimizing anisotropic 3D Gaussians through differentiable rasterization. The efficient rasterizer allows high-fidelity views to be rendered in real time.

Each 3D Gaussian $i$ is associated with a position $\mu_i$, a covariance matrix $\Sigma_i$, an opacity $\sigma_i$, and spherical harmonics coefficients $h_i$.
At an arbitrary spatial point $x$, the opacity contributed by the 3D Gaussian is
$\alpha_i = \sigma_i \exp\!\left(-\tfrac{1}{2}(x-\mu_i)^T \Sigma_i^{-1}(x-\mu_i)\right). \quad (1)$
$\Sigma_i$ is positive semi-definite and can be decomposed into a scaling matrix $S_i$ and a rotation matrix $R_i$:
$\Sigma_i = R_i S_i S_i^T R_i^T$

Rendering pipeline, 3D → 2D:

  • The 3D Gaussians are first projected to 2D image space using an approximation of the perspective transformation: each projected Gaussian is approximated by a 2D Gaussian with center $\mu_i^{2D}$ and covariance $\Sigma_i^{2D}$, where $W$ and $K$ denote the view transformation and the projection matrix (see the sketch after this list).
    Center:
    $\mu_i^{2D} = \left(K\big((W\mu_i)/(W\mu_i)_z\big)\right)_{1:2}$
    Covariance:
    $\Sigma_i^{2D} = \left(J\,W\,\Sigma_i\,W^T J^T\right)_{1:2,\,1:2}$
    where $W$ is the rotation/translation (view) matrix and $J$ is the Jacobian of the projective transformation $K$.

  • The Gaussians are sorted by depth, and the pixel color is obtained by volume rendering:
    $I = \sum_{i\in N} c_i\,\alpha_i^{2D} \prod_{j=1}^{i-1}\left(1-\alpha_j^{2D}\right)$
    where $\alpha_i^{2D}$ is the 2D version of Eq. (1), with $\mu_i$, $\Sigma_i$, and $x$ replaced by their pixel-space counterparts, and $c_i$ is the RGB color obtained by evaluating the SH with the view direction and the coefficients $h_i$.
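
A minimal PyTorch sketch of the projection step above, assuming a 3×3 pinhole intrinsic matrix K and a 4×4 world-to-camera matrix W; the helper names (project_gaussian, alpha_at_pixel) are illustrative and are not part of the paper's rasterizer:

import torch

def project_gaussian(mu, Sigma, W, K):
    """Project one 3D Gaussian (mu, Sigma) onto the image plane.

    mu: (3,) world-space mean; Sigma: (3,3) world-space covariance
    W:  (4,4) world-to-camera transform; K: (3,3) pinhole intrinsics
    """
    # Camera-space mean, then pixel-space center (perspective divide + intrinsics)
    mu_cam = (W @ torch.cat([mu, mu.new_ones(1)]))[:3]
    mu2d = (K @ (mu_cam / mu_cam[2]))[:2]

    # Jacobian J of the perspective projection, evaluated at mu_cam
    x, y, z = mu_cam.tolist()
    fx, fy = K[0, 0].item(), K[1, 1].item()
    J = torch.tensor([[fx / z, 0.0, -fx * x / z ** 2],
                      [0.0, fy / z, -fy * y / z ** 2]], dtype=Sigma.dtype)

    # 2D covariance: rotate Sigma into camera space, then push it through J
    R_wc = W[:3, :3]
    Sigma2d = J @ R_wc @ Sigma @ R_wc.T @ J.T
    return mu2d, Sigma2d

def alpha_at_pixel(p, mu2d, Sigma2d, sigma):
    """2D version of Eq. (1): the splat's opacity contribution at pixel p."""
    d = p - mu2d
    return sigma * torch.exp(-0.5 * d @ torch.linalg.inv(Sigma2d) @ d)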

Method

Point clouds as input
Spacetime Gaussian

  • Adds time to the 3D Gaussians.
  • Introduces a temporal radial basis function to encode temporal opacity, which effectively models the emergence and disappearance of scene content.
  • Uses time-conditioned parametric functions to model the motion and rotation of the 3D Gaussians.

For a spacetime point $(x, t)$, the opacity of an STG is
$\alpha_i(t) = \sigma_i(t)\exp\!\left(-\tfrac{1}{2}(x-\mu_i(t))^T \Sigma_i^{-1}(t)\,(x-\mu_i(t))\right)$

Temporal radial basis function → temporal opacity

$\sigma_i(t) = \sigma_i^{s} \exp\!\left(-s_i^{\tau}\,|t-\mu_i^{\tau}|^2\right)$
where $\mu_i^{\tau}$ is the temporal center, $s_i^{\tau}$ is the temporal scaling factor, and $\sigma_i^{s}$ is the time-independent spatial opacity.
The temporal radial basis function is passed into the renderer:

render_pkg = render(viewpoint_cam, gaussians, pipe, background,  
                    override_color=None,  
                    basicfunction=rbfbasefunction,     # temporal radial basis function
                    GRsetting=GRsetting, GRzer=GRzer)

Implementation details inside the renderer:

# pc: the GaussianModel instance
pointtimes = torch.ones((pc.get_xyz.shape[0], 1), 
                         dtype=pc.get_xyz.dtype, 
                         requires_grad=False, device="cuda") + 0

# temporal center and scale parameters
trbfcenter = pc.get_trbfcenter
trbfscale = pc.get_trbfscale
pointopacity = pc.get_opacity    # time-independent spatial opacity sigma_i^s

trbfdistanceoffset = viewpoint_camera.timestamp * pointtimes - trbfcenter # t - mu_i^tau
trbfdistance =  trbfdistanceoffset / torch.exp(trbfscale)     # scale by the temporal scale factor
trbfoutput = basicfunction(trbfdistance)

opacity = pointopacity * trbfoutput  # temporal opacity sigma_i(t)
pc.trbfoutput = trbfoutput

# the temporal radial basis function itself: exp(-x^2)
def trbfunction(x): 
    return torch.exp(-1*x.pow(2))

Time-conditioned function → motion & rotation

Polynomial motion trajectory

For each STG, its motion is modeled by a time-conditioned function:
$\mu_i(t) = \sum_{k=0}^{n_p} b_{i,k}\,(t-\mu_i^{\tau})^k$
where $\mu_i(t)$ is the spatial position of the STG at time $t$, and $\{b_{i,k}\}_{k=0}^{n_p}$ are learnable polynomial coefficients.
Combining the temporal radial basis function with the time-conditioned polynomial motion trajectory lets a complex, long motion be represented by several simple ones.
$n_p = 3$ in the paper.

Implementation in the renderer:

means3D = pc.get_xyz
tforpoly = trbfdistanceoffset.detach()    # t - mu_i^tau
# cubic motion polynomial (n_p = 3): mu_i(t) = mu_i + b1*dt + b2*dt^2 + b3*dt^3
means3D = (means3D
           + pc._motion[:, 0:3] * tforpoly
           + pc._motion[:, 3:6] * tforpoly * tforpoly
           + pc._motion[:, 6:9] * tforpoly * tforpoly * tforpoly)

Polynomial rotation

The rotation matrix $R_i$ in $\Sigma_i = R_i S_i S_i^T R_i^T$ is parameterized by a real-valued quaternion.
Analogous to the motion trajectory, the quaternion is represented by a polynomial:
$q_i(t) = \sum_{k=0}^{n_q} c_{i,k}\,(t-\mu_i^{\tau})^k$
where $q_i(t)$ is the rotation of the STG at time $t$, and $\{c_{i,k}\}_{k=0}^{n_q}$ are polynomial coefficients.
$n_q = 1$ in the paper.
After converting $q_i(t)$ to a rotation matrix $R_i(t)$, the covariance $\Sigma_i(t)$ is obtained via $\Sigma_i(t) = R_i(t)\,S_i S_i^T\,R_i(t)^T$.

Implementation in the renderer:

rotations = pc.get_rotation(tforpoly)    # t - mu_i^tau

# first-order polynomial (n_q = 1): q_i(t) = q_i + omega_i * (t - mu_i^tau), then normalize
def get_rotation(self, delta_t):
    rotation =  self._rotation + delta_t * self._omega
    self.delta_t = delta_t
    return self.rotation_activation(rotation)
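
The covariance itself is typically assembled from scales and rotations inside the rasterizer. As a reference, a minimal PyTorch sketch of the step $q_i(t) \rightarrow R_i(t) \rightarrow \Sigma_i(t)$, assuming normalized quaternions in (w, x, y, z) order and per-axis scales after the exp activation; the helper names are illustrative, not the paper's implementation:

import torch

def quaternion_to_rotation(q):
    """Normalized quaternions (N, 4) in (w, x, y, z) order -> rotation matrices (N, 3, 3)."""
    w, x, y, z = q.unbind(dim=-1)
    return torch.stack([
        torch.stack([1 - 2*(y*y + z*z), 2*(x*y - w*z),     2*(x*z + w*y)],     dim=-1),
        torch.stack([2*(x*y + w*z),     1 - 2*(x*x + z*z), 2*(y*z - w*x)],     dim=-1),
        torch.stack([2*(x*z - w*y),     2*(y*z + w*x),     1 - 2*(x*x + y*y)], dim=-1),
    ], dim=-2)

def build_covariance(q_t, scales):
    """Sigma_i(t) = R_i(t) S_i S_i^T R_i(t)^T for a batch of Gaussians.

    q_t:    (N, 4) rotation quaternions at time t (already normalized)
    scales: (N, 3) per-axis scales
    """
    R = quaternion_to_rotation(q_t)   # (N, 3, 3)
    S = torch.diag_embed(scales)      # (N, 3, 3)
    M = R @ S
    return M @ M.transpose(-1, -2)    # R S S^T R^T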

Splatted Feature Rendering

To encode view- and time-dependent radiance both accurately and compactly, the per-Gaussian appearance is stored as neural features instead of spherical harmonics.
The features $f_i(t) \in \mathbb{R}^9$:
$f_i(t) = \left[f_i^{base},\ f_i^{dir},\ (t-\mu_i^{\tau})\,f_i^{time}\right]$
where $f_i^{base} \in \mathbb{R}^3$ is the base RGB color, and $f_i^{dir}, f_i^{time} \in \mathbb{R}^3$ encode view-direction- and time-related information.
The features $f_i(t)$ replace the RGB color $c_i$ in $I = \sum_{i\in N} c_i\,\alpha_i^{2D} \prod_{j=1}^{i-1}(1-\alpha_j^{2D})$.

After splatting to image space, the splatted features at each pixel are split into $F^{base}$, $F^{dir}$, and $F^{time}$.

Implementation in the renderer:

colors_precomp = pc.get_features(tforpoly)  # per-Gaussian features f_i(t), passed in place of precomputed colors
rendered_image, radii, depth = rasterizer(
                                            means3D = means3D,    # xyz
                                            means2D = means2D,    # screenspace points
                                            shs = shs,
                                            colors_precomp = colors_precomp,
                                            opacities = opacity,
                                            scales = scales,
                                            rotations = rotations,
                                            cov3D_precomp = cov3D_precomp)


def get_features(self, deltat):
    # concatenate [f_base, f_dir] with (t - mu_i^tau) * f_time -> 9-channel feature
    return torch.cat((self._features_dc, deltat * self._features_t), dim=1)

The final RGB color at each pixel is obtained by passing the splatted features through a 2-layer MLP $\Phi$:
$I = F^{base} + \Phi(F^{dir}, F^{time}, r)$
where $r$ is the view direction at the pixel, additionally concatenated with the features as input.

rgbdecoder = Sandwich(9, 3)

rendered_image = pc.rgbdecoder(rendered_image.unsqueeze(0), 
                               viewpoint_camera.rays, # r, the view direction at the pixel
                               viewpoint_camera.timestamp) # 1 , 3
rendered_image = rendered_image.squeeze(0)


class Sandwich(nn.Module):
    def __init__(self, dim, outdim=3, bias=False):
        super(Sandwich, self).__init__()
        
        self.mlp1 = nn.Conv2d(12, 6, kernel_size=1, bias=bias)
        self.mlp2 = nn.Conv2d(6, 3, kernel_size=1, bias=bias)
        self.relu = nn.ReLU()

        self.sigmoid = torch.nn.Sigmoid()
        
    def forward(self, input, rays, time=None):
        albedo, spec, timefeature = input.chunk(3, dim=1)    # F_base, F_dir, F_time
        specular = torch.cat([spec, timefeature, rays], dim=1)  # 3 + 3 + 6 = 12 channels (rays has 6 channels here)
        specular = self.mlp1(specular)
        specular = self.relu(specular)
        specular = self.mlp2(specular)

        result = albedo + specular
        result = self.sigmoid(result) 
        return result

Sampling

Problem addressed: regions with sparse Gaussians and regions far from the training cameras are hard to converge to high-quality rendering.

Implementation

Find regions with large training errors → pick the center pixel of each region → restrict the depth range → cast rays through the center pixels → add new Gaussians along the rays.

  • Sampling is performed only after the training loss has stabilized, so that the sampled errors are meaningful;
  • Regions with large errors are selected by aggregating training errors over patches (code below);
ssimcurrent = ssim(image.detach(), gt_image.detach()).item()
if ssimcurrent < 0.88:
    imageadjust = image / (torch.mean(image) + 0.01)
    gtadjust = gt_image / (torch.mean(gt_image) + 0.01)
    diff = torch.abs(imageadjust - gtadjust)
    diff = torch.sum(diff, dim=0) # h, w
    
    # use the diff value at the opt.emsthr quantile as the threshold (opt.emsthr = 0.6)
    diff_sorted, _ = torch.sort(diff.reshape(-1)) 
    numpixels = diff.shape[0] * diff.shape[1]
    threshold = diff_sorted[int(numpixels*opt.emsthr)].item()
    
    # mark pixels with large errors
    outmask = diff > threshold    
    kh, kw = 16, 16 # kernel size
    dh, dw = 16, 16 # vertical and horizontal strides
    # compute padding  
    idealh, idealw = int(image.shape[1] / dh  + 1) * kw, int(image.shape[2] / dw + 1) * kw 
    outmask = torch.nn.functional.pad(
                         outmask, 
                         (0, idealw - outmask.shape[1], 0, idealh - outmask.shape[0]), 
                         mode='constant', value=0)
    
    # split outmask into 16x16 patches with a sliding window                 
    patches = outmask.unfold(0, kh, dh).unfold(1, kw, dw)  
    
    # keep only patches where enough pixels are marked as errors
    dummypatch = torch.ones_like(patches)
    # count the error pixels inside each patch
    patchessum = patches.sum(dim=(2,3))
    patchesmusk = patchessum  >  kh * kh * 0.85
    patchesmusk = patchesmusk.unsqueeze(2).unsqueeze(3).repeat(1,1,kh,kh).float()
    patches = dummypatch * patchesmusk 
  • Locate the center region of each qualifying patch and select its center pixels (code below);
    midpatch = torch.ones_like(patches)
    
    # zero out entries whose row and column indices are both even
    for i in range(0, kh,  2):
        for j in range(0, kw, 2):
            midpatch[:,:, i, j] = 0.0  
    
    # keep only a subset of the marked pixels in each patch (its "center" pixels)
    centerpatches = patches * midpatch
    
    unfold_shape = patches.size()
    patches_orig = patches.view(unfold_shape)
    centerpatches_orig = centerpatches.view(unfold_shape)
    
    output_h = unfold_shape[0] * unfold_shape[2]
    output_w = unfold_shape[1] * unfold_shape[3]
    patches_orig = patches_orig.permute(0, 2, 1, 3).contiguous()
    centerpatches_orig = centerpatches_orig.permute(0, 2, 1, 3).contiguous()
    centermask = centerpatches_orig.view(output_h, output_w).float() # H * W mask, 1 for error, 0 for no error
    # crop back to the original image size
    centermask = centermask[:image.shape[1], :image.shape[2]]
    
    errormask = patches_orig.view(output_h, output_w).float() # H * W mask, 1 for error, 0 for no error
    errormask = errormask[:image.shape[1], :image.shape[2]]
    
    # ignore a 10% border on each side of the image
    H, W = centermask.shape
    
    offsetH = int(H/10)
    offsetW = int(W/10)
    
    centermask[0:offsetH, :] = 0.0
    centermask[:, 0:offsetW] = 0.0
    
    centermask[-offsetH:, :] = 0.0
    centermask[:, -offsetW:] = 0.0
    
    # indices of pixels marked as high-error
    badindices = centermask.nonzero()
  • To avoid sampling over an overly large volume, a depth range is defined and new Gaussians are sampled only within it; the coarse depth map rendered from the Gaussian centers bounds the sampling depth (code below);
    render_pkg = render(viewpoint_cam, gaussians, pipe, background, 
                        override_color=None, basicfunction=rbfbasefunction, 
                        GRsetting=GRsetting, GRzer=GRzer)

    depth = render_pkg["depth"]
                         
    diff_sorted , _ = torch.sort(depth.reshape(-1)) 
    N = diff_sorted.shape[0]
    mediandepth = int(0.7 * N)
    mediandepth = diff_sorted[mediandepth]   # depth value at the 70th percentile
    depth = torch.where(depth>mediandepth, depth, mediandepth) 
    
    # add new Gaussians within the (mediandepth, depthmax] depth range
    totalNnewpoints = gaussians.addgaussians(badindices, 
                                             viewpoint_cam, 
                                             depth, 
                                             gt_image, 
                                             numperay=opt.farray, ratioend=opt.rayends,  
                                             depthmax=depthdict[viewpoint_cam.image_name], 
                                             shuffle=(opt.shuffleems != 0))     
     
    visibility_filter = torch.cat((visibility_filter, torch.zeros(totalNnewpoints).cuda(0)), dim=0)
    radii = torch.cat((radii, torch.zeros(totalNnewpoints).cuda(0)), dim=0)
    viewspace_point_tensor = torch.cat((viewspace_point_tensor, torch.zeros(totalNnewpoints, 3).cuda(0)), dim=0)
  • Sample new Gaussians along the rays through the high-error pixels (addgaussians below);
    def addgaussians(self, baduvidx, viewpoint_cam, depthmap, gt_image, numperay=3, ratioend=2, 
    				 trbfcenter=0.5, depthmax=None, shuffle=False):
        def pix2ndc(v, S):
            return (v * 2.0 + 1.0) / S - 1.0
        
        # accumulators for the attributes of the new Gaussians
        new_xyz, new_features_dc, new_trbf_center, new_trbf_scale = [], [], [], []
        new_motion, new_omega, new_featuret = [], [], []
        
        rgbs = gt_image[:, baduvidx[:,0], baduvidx[:,1]]
        rgbs = rgbs.permute(1,0)
        # should we add the feature dc with non zero values? direction feature
        featuredc = torch.cat((rgbs, torch.zeros_like(rgbs)), dim=1)

        depths = depthmap[:, baduvidx[:,0], baduvidx[:,1]]
        # only use depth map > 15 .
        depths = depths.permute(1,0) 
        
        # use the max local depth for the scene ?
        depths = torch.ones_like(depths) * depthmax 
        
        maxx, minx = self.maxx, self.minx 
        
        # pixel coordinates of the high-error points
        u = baduvidx[:,0] # height (y)
        v = baduvidx[:,1] # width (x)
        
        # depth ratios from self.raystart to ratioend
        ratiaolist = torch.linspace(self.raystart, ratioend, numperay) 
        # projectinverse / camera2wold below are the camera's inverse projection and
        # camera-to-world matrices derived from viewpoint_cam (omitted in this excerpt)
        for zscale in ratiaolist :
            ndcu, ndcv = pix2ndc(u, viewpoint_cam.image_height), pix2ndc(v, viewpoint_cam.image_width)
            # targetPz = depths * zscale # depth in local cameras..
            if shuffle == True:
                randomdepth = torch.rand_like(depths) - 0.5 # -0.5 to 0.5
                # jitter the depth around `depths`
                targetPz = (depths + depths/10*(randomdepth)) * zscale 
            else:
                targetPz = depths*zscale # depth in local cameras..
            
            ndcu = ndcu.unsqueeze(1)
            ndcv = ndcv.unsqueeze(1)
 
            ndccamera = torch.cat((ndcv, 
                                   ndcu,   
                                   torch.ones_like(ndcu) * (1.0), 
                                   torch.ones_like(ndcu)), dim=1) # N,4 ...
            # unproject from NDC to camera coordinates
            localpointuv = ndccamera @ projectinverse.T 
            
            # ray direction in camera space 
            diretioninlocal = localpointuv / localpointuv[:,3:] 

            # ratio between the target depth and the ray's z coordinate
            rate = targetPz / diretioninlocal[:, 2:3]
            
            # point on the ray at the target depth
            localpoint = diretioninlocal * rate
            localpoint[:, -1] = 1
            
            # transform to world coordinates
            worldpointH = localpoint @ camera2wold.T   
            worldpoint = worldpointH / worldpointH[:, 3:]
            
            # target points on the rays, in world coordinates
            xyz = worldpoint[:, :3] 
            
            # points with x in (minx, maxx)
            xmask = torch.logical_and(xyz[:, 0] > minx, xyz[:, 0] < maxx )
            # the OR with its negation makes the mask all-True, i.e. every point is kept
            selectedmask = torch.logical_or(xmask, torch.logical_not(xmask))
            # store the target points (points on the rays)
            new_xyz.append(xyz[selectedmask]) 
            new_features_dc.append(featuredc.cuda(0)[selectedmask])
            
            selectnumpoints = torch.sum(selectedmask).item()
            new_trbf_center.append(torch.rand((selectnumpoints, 1)).cuda())

            assert self.trbfslinit < 1 
            new_trbf_scale.append(self.trbfslinit * torch.ones((selectnumpoints, 1), device="cuda"))
            new_motion.append(torch.zeros((selectnumpoints, 9), device="cuda")) 
            new_omega.append(torch.zeros((selectnumpoints, 4), device="cuda"))
            new_featuret.append(torch.zeros((selectnumpoints, 3), device="cuda"))
        
        new_xyz = torch.cat(new_xyz, dim=0)
        new_rotation = torch.zeros((new_xyz.shape[0],4), device="cuda")
        new_rotation[:, 1]= 0
        
        new_features_dc = torch.cat(new_features_dc, dim=0)
        new_opacity = inverse_sigmoid(0.1 *torch.ones_like(new_xyz[:, 0:1]))
        new_trbf_center = torch.cat(new_trbf_center, dim=0)
        new_trbf_scale = torch.cat(new_trbf_scale, dim=0)
        new_motion = torch.cat(new_motion, dim=0)
        new_omega = torch.cat(new_omega, dim=0)
        new_featuret = torch.cat(new_featuret, dim=0)

        tmpxyz = torch.cat((new_xyz, self._xyz), dim=0)
        dist2 = torch.clamp_min(distCUDA2(tmpxyz), 0.0000001)
        dist2 = dist2[:new_xyz.shape[0]]
        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
        scales = torch.clamp(scales, -10, 1.0)
        new_scaling = scales 
        
        # densify: append the new Gaussians (and their attributes) to the model
        self.densification_postfix(new_xyz, new_features_dc, new_opacity, new_scaling, new_rotation, new_trbf_center, 
        						   new_trbf_scale, new_motion, new_omega, new_featuret)

Focusing on the marked pixels
Mask out correctly rendered pixels and keep only the high-error ones, so that the next iterations concentrate on them:

    gt_image = gt_image * errormask
    image = render_pkg["render"] * errormask
    
    torchvision.utils.save_image(gt_image, 
                                 os.path.join(pathdir,  "maskedudgt" + str(iteration) + ".png"))
    torchvision.utils.save_image(image, 
                                 os.path.join(pathdir,  "maskedrender" + str(iteration) + ".png"))

Advantages

The feature-based representation needs fewer parameters than spherical-harmonics encoding and renders faster.
$\Phi$ can be removed entirely to speed up rendering further (the lite version).

Loss

The loss is computed between the rendered image and the ground-truth image, combining an $L_1$ term and a D-SSIM term.
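
A minimal sketch of such a loss, reusing the ssim function referenced above; the weight lambda_dssim is an assumed value for illustration, not taken from the paper:

import torch

def l1_loss(pred, gt):
    return torch.abs(pred - gt).mean()

def photometric_loss(pred, gt, lambda_dssim=0.2):
    # weighted combination of L1 and D-SSIM (D-SSIM = 1 - SSIM)
    return (1.0 - lambda_dssim) * l1_loss(pred, gt) + lambda_dssim * (1.0 - ssim(pred, gt))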
