【三维重建】【深度学习】NeuS代码Pytorch实现--训练阶段代码解析(中)

【三维重建】【深度学习】NeuS代码Pytorch实现–训练阶段代码解析(中)

论文提出了一种新颖的神经表面重建方法,称为NeuS,用于从2D图像输入以高保真度重建对象和场景。在NeuS中建议将曲面表示为有符号距离函数(SDF)的零级集,并开发一种新的体绘制方法来训练神经SDF表示,因此即使没有掩模监督,也可以实现更准确的表面重建。NeuS在高质量的表面重建方面的性能优于现有技术,特别是对于具有复杂结构和自遮挡的对象和场景。本篇博文将根据代码执行流程解析训练阶段具体的功能模块代码。



前言

在详细解析NeuS网络之前,首要任务是搭建NeuS【win10下参考教程】所需的运行环境,并完成模型的训练和测试,展开后续工作才有意义。
本博文继续对NeuS训练阶段涉及的部分功能代码模块进行解析,其他代码模块后续的博文将会陆续讲解。

博主将各功能模块的代码在不同的博文中进行了详细的解析,点击【win10下参考教程】,博文的目录链接放在前言部分。


SDFNetwork网络

在exp_runner.py文件内class Runner的__init__函数中完成初始化,并将其传递给NeuS,以方便后续SDF网络的使用。

# sdf网络
self.sdf_network = SDFNetwork(**self.conf['model.sdf_network']).to(self.device)
# 初始化neus神经网络
self.renderer = NeuSRenderer(self.nerf_outside,
                             self.sdf_network,
                             self.deviation_network,
                             self.color_network,
                             **self.conf['model.neus_renderer'])

在models/renderer.py文件内的render函数上调用。

# p(t) = {o + tv|t ≥ 0}
# 世界坐标系下采样点位置=光心+单位方向向量*采样点
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]    # [batch_size,n_samples,3]
# 获取等势面
sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)  # [batch_size,n_samples,1]

世界坐标系下采样点位置如下图所示。

注意:p(t) = {o + tv|t ≥ 0}是论文提到的,但是博主运行代码发现t也可能小于0(直接print打印),也就是说采样点在光心的后面(负方向延长线上),可能后续会因为训练而优化吧。

SDFNetwork网络初始化

SDFNetwork的定义在models/fields.py文件内。

def __init__(self,
             d_in,          # 输入channel
             d_out,         # 输出channnel
             d_hidden,      # 隐藏层channel
             n_layers,      # 网络层数
             skip_in=(4,),  # 网络中间插入新输入的位置
             multires=0,    # 位置编码长度
             bias=0.5,      # 用于网络层偏置的初始化赋值
             scale=1,       # 缩放比例
             geometric_init=True,   # 指定初始化形式
             weight_norm=True,      # 权重归一化
             inside_outside=False):
    super(SDFNetwork, self).__init__()

    # 各层的channel
    dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]

    # 位置编码
    self.embed_fn_fine = None
    if multires > 0:
        # 编码函数 编码输出的channel
        embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
        self.embed_fn_fine = embed_fn
        dims[0] = input_ch

    # 网络深度
    self.num_layers = len(dims)
    # 网络插入新输入的位置
    self.skip_in = skip_in
    # 缩放比例
    self.scale = scale

    for l in range(0, self.num_layers - 1):
        if l + 1 in self.skip_in:
            out_dim = dims[l + 1] - dims[0]
        else:
            out_dim = dims[l + 1]

        lin = nn.Linear(dims[l], out_dim)

        # 对网络模型权重是否进行手工初始化,否则随机初始化
        if geometric_init:
            if l == self.num_layers - 2:    # 倒数第二层
                if not inside_outside:
                    # 权重满足正态分布
                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                    # 偏置设置为常数
                    torch.nn.init.constant_(lin.bias, -bias)
                else:
                    torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                    torch.nn.init.constant_(lin.bias, bias)
            elif multires > 0 and l == 0:   # 第一层
                # 偏置设置为常数
                torch.nn.init.constant_(lin.bias, 0.0)
                # 前三个channel权重设置为常数(3D坐标输入),weight的shape是[C_out,C_in]
                torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                # 其余输入channel权重满足正态分布
                torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
            elif multires > 0 and l in self.skip_in:    # 中间有薪输入的网络层
                # 偏置设置为常数
                torch.nn.init.constant_(lin.bias, 0.0)
                # 其余输入channel权重满足正态分布
                torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                # 因为拼接,中间某三个channel权重设置为常数(3D坐标输入)
                torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
            else:   # 其他网络层
                # 偏置设置为常数
                torch.nn.init.constant_(lin.bias, 0.0)
                # 权重满足正态分布
                torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
        if weight_norm:
            # 权重归一化处理
            lin = nn.utils.weight_norm(lin)
        # 全连接层命名
        setattr(self, "lin" + str(l), lin)
    # 激活函数 log(1+e^x)
    self.activation = nn.Softplus(beta=100)

SDFNetwork网络权重w和偏置b初始化示意图。

橙色表示正态分布方式初始化:不同颜色深度是因为初始化时正太分布的均值和方差不同。
灰色表示常数方式初始化:不同颜色深度是因为初始化时常数的值不同。
有俩个地方比较特殊,就是在第一层和第四层因为输入中存在原始3D坐标信息,与3D坐标相关联的channel的权重是用常数方式初始化的。

SDFNetwork网络结构

def forward(self, inputs):
    # 这里输入的是3D坐标,可以坐标的缩放
    inputs = inputs * self.scale
    # 对3D坐标做位置编码
    if self.embed_fn_fine is not None:
        inputs = self.embed_fn_fine(inputs)
    x = inputs

    for l in range(0, self.num_layers - 1):
        lin = getattr(self, "lin" + str(l))
        # 中间有新的输入,新收入与上层网络的输出做拼接
        if l in self.skip_in:
            x = torch.cat([x, inputs], 1) / np.sqrt(2)
        x = lin(x)
        # 不是最后一层网络,需要经过激活层
        if l < self.num_layers - 2:
            x = self.activation(x)
    # 1:表示sdf的值,因为3D坐标进行缩放,因此等势面值也需要缩放
    # 256:理解为隐藏特征
    return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1)		# [batch_size,1+256]

SDFNetwork网络结构及其执行流程如下图所示。

γ(x)表示对3D坐标点做位置编码,作为网络的输入。
输出是1+256:1表示sdff的值,表示等势面;256表示隐藏的中间特征。

SDFNetwork网络的使用

SDFNetwork定义了其他函数满足不同的任务需求。

  • 只获取sdf值(等势面)
def sdf(self, x):
   # 只输出sdf值
   return self.forward(x)[:, :1]
  • 获取sdf值和隐藏特征
def sdf_hidden_appearance(self, x):
    # 输出sdf+隐藏特征
    return self.forward(x)
  • 获取与sdf相关的梯度
def gradient(self, x):
    # sdf相关的梯度信息
    x.requires_grad_(True)
    y = self.sdf(x)
    d_output = torch.ones_like(y, requires_grad=False, device=y.device)
    gradients = torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=d_output,
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]
    return gradients.unsqueeze(1)

位置编码

在组成NeuS神经网络模型的各个子模型中,会对输入的3D坐标进行位置编码,用get_embedder为调用函数提供位置编码器。在models/embedder.py文件内。

def get_embedder(multires, input_dims=3):
    embed_kwargs = {
        'include_input': True,          # 是否在编码channel中包含原始输入channel
        'input_dims': input_dims,       # 输入需要编码的channel
        'max_freq_log2': multires-1,    # 最大编码频率
        'num_freqs': multires,          # 编码长度
        'log_sampling': True,           # (对数采样)采样方式选择
        'periodic_fns': [torch.sin, torch.cos],     # 周期函数
    }
    # 初始化编码器
    embedder_obj = Embedder(**embed_kwargs)
    # 定义编码器函数
    def embed(x, eo=embedder_obj): return eo.embed(x)
    # 返回编码器函数和编码输出channel
    return embed, embedder_obj.out_dim

Embedder类作为编码器,根据指定编码长度对输入的3D坐标(x,y,z)进行编码。

class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):

        embed_fns = []
        # 输入需要编码的channel
        d = self.kwargs['input_dims']
        out_dim = 0

        # 在编码channel中包含原始输入channel
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            # 统计编码输出channel
            out_dim += d
        # 最大编码频率
        max_freq = self.kwargs['max_freq_log2']
        # 编码长度
        N_freqs = self.kwargs['num_freqs']

        # 取N_freqs个系数,后续用于在对应的周期函数上取点
        # 采样方式不同,采样区间不同
        if self.kwargs['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)

        # 对输入的3D坐标,根据freq系数在对应的周期函数上取点
        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn = p_fn, freq = freq: p_fn(x * freq))
                out_dim += d

        # 周期函数取点位置集合
        self.embed_fns = embed_fns
        # 编码输出channel
        self.out_dim = out_dim

    def embed(self, inputs):
        # torch格式:对输入的channel进行位置编码
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)

根据N_freqs个等差间隔的系数,3D坐标(x,y,z)在cos和sin周期函数上共取3×2×N_freqs个编码值,作为最终的位置编码,也可以额外附加原始的3D坐标(x,y,z)到位置编码(6×N_freqs+3)。以3D坐标x坐标为例,编码过程如下图所示。

如上图,假设N_freqs是14 则x坐标获得的编码为2×14+1=29,(x,y,z)三坐标则为87。

光线rays上进行前景细采样

在models/renderer.py文件的render函数内。

for i in range(self.up_sample_steps):
	# 通过多次切香肠让精采样点不断逼近sdf=0的等势面(物体表面)
    new_z_vals = self.up_sample(rays_o,     # 光心
                                rays_d,     # 单位方向向量
                                z_vals,     # 采样点
                                sdf,        # sdf值
                                self.n_importance // self.up_sample_steps,
                                64 * 2**i)
    z_vals, sdf = self.cat_z_vals(rays_o,
                                  rays_d,
                                  z_vals,
                                  new_z_vals,
                                  sdf,
                                  last=(i + 1 == self.up_sample_steps))

up_sample

up_sample函数代码比较简单,但内容比较丰富,理解存在难度(博主个人觉得),比较难理清每行代码乃至每个变量表达的含义和目的,因此博主将函数代码拆分成几段分别讲解,懂得可以快速过。

  • 有效采样段(inside_sphere)
batch_size, n_samples = z_vals.shape
# 世界坐标系下粗采样点位置
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]  # [batch_size, n_samples, 3]
# 原点到世界坐标系下粗采样点的距离
radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)         # [batch_size, n_samples]
# 世界坐标系下相邻俩个粗采样点作为一个采样段
# 采样段的俩端有一个端点在半径为1的球体内,则采样端有效
inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)        # [batch_size, n_samples-1]

sdf = sdf.reshape(batch_size, n_samples)                # [batch_size, n_samples]
# 世界坐标系下近点集的sdf和世界坐标系下远点集的sdf
prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]    # [batch_size, n_samples-1] [batch_size, n_samples-1]

# 近点集和远点集
prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]    # [batch_size, n_samples-1] [batch_size, n_samples-1]

# 粗采样点相邻俩点的sdf的均值
mid_sdf = (prev_sdf + next_sdf) * 0.5       # [batch_size, n_samples-1]

该段代码对应的三维立体图和二维坐标图的表示。

图中红色粗采样点(世界坐标系下)在球体外,绿色在求体内,黄色点是计算出的前后俩采样点的sdf均值(mid_sdf);红色采样段是false(无效),绿色采样段是true(有效)。

inside_sphere是有效采样段是博主个人理解上的直观叫法,它后续是与概率密度相关联的。采样段是否有效在于采样段的端点,即构成采样段的相邻俩个的粗采样点(世界坐标系下)是否在球体内部。

  • 采样段的sdf密度
# sdf密度:相邻两个采样点的sdf值与采样点之间的距离dist的比值
cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)     # [batch_size, n_samples-1]

prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1)   # [batch_size, n_samples-1]
cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)      # [batch_size, n_samples-1,2]

# sdf密度选择最小值
cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)       # [batch_size, n_samples-1]

# 选择有效采样段内的sdf密度,且sdf密度的区间限定为[-1000,0]
# 部分sdf大于0,但采样段是有效段,因此sdf变为0
cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere       # [batch_size, n_samples-1]

该段代码对应的二维坐标图的表示。

左图中prev_cos_val 和cos_val是未执行torch.stack之前的;右图中min_cos_val是经过torch.min和cos_val.clip后的cos_val。这里为了方便理解。

  • 采样段的概率密度因子
# 世界坐标系下粗采样点的距离
dist = (next_z_vals - prev_z_vals)
# 每段采样段前半段的sdf总值
prev_esti_sdf = mid_sdf - cos_val * dist * 0.5          # [batch_size, n_samples-1]
# 每段采样段后半段的sdf总值
next_esti_sdf = mid_sdf + cos_val * dist * 0.5          # [batch_size, n_samples-1]

# 每段采样段前半段的cdf值
prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)         # [batch_size, n_samples-1]
# 每段采样段后半段的cdf值
next_cdf = torch.sigmoid(next_esti_sdf * inv_s)         # [batch_size, n_samples-1]
# 每段采样段的概率密度因子
alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)    # [batch_size, n_samples-1]

该段代码对应的二维坐标图的表示。

alpha博主个人理解为每段采样段的概率密度因子:当采样段无效时,alpha等于0;采样段有效时,alpha大于0。alpha由0变为大于0的值则刚好是采样段穿入球体时,反之则为采样段刚好是穿出球体时。

  • 采样段的权重
# 每段采样段的权重
weights = alpha * torch.cumprod(
    torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]    # [batch_size, n_samples-1]

采样段权重=概率密度因子alpha×权重系数,当前采样段权重系数为(1-前面的所有采样段的权重)。
穿过球体表面的采样段权重通常最大:概率密度因子alpha×权重系数1。当然,无效采样段权重只能是0.

源码用torch.cumprod是累积乘法巧妙的实现。

sample_pdf

# 这段代码的实现来自nerf
# 防止因权重为0出现nans
weights = weights + 1e-5  # prevent nans
# 采样段概率密度:各采样段权重在所有采样段权重总和的占比
pdf = weights / torch.sum(weights, -1, keepdim=True)    # [batch_size, n-1]
# 累加概率密度,范围0~1
cdf = torch.cumsum(pdf, -1)     # [batch_size, n-1]
# 累加概率密度前新添加了0概率密度
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # [batch_size, n]
# Take uniform samples
if det:
    # 可以理解为n_samples个细采样点的cdf值,在(0.5 / n_samples,1. - 0.5 / n)范围内等差取样
    # 目前只有细采样点的cdf值, 这不是在光线rays上的位置
    u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples) # [n_samples]
    u = u.expand(list(cdf.shape[:-1]) + [n_samples])     # [batch_size, n_samples]
else:
    u = torch.rand(list(cdf.shape[:-1]) + [n_samples])

# Invert CDF
u = u.contiguous()
# 细采样点的cdf值与采样段的cdf比较,并将采样点分配到对应的采样段(序号)
# inds就是采样段的序号,这里的序号也可以对应组成采样段的后粗采样点(后端点)的序号
inds = torch.searchsorted(cdf, u, right=True)
# max:最小不能小于第一个采样点序号
# 组成采样段的前粗采样点(前端点:inds-1)
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
# min:最大不能大于最后一个采样点序号
# 组成采样段的后粗采样点(后端点:inds)
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)

inds_g = torch.stack([below, above], -1)                                # [batch_size, n_samples, 2]
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]       # [batch_size, n_samples, n]
# 获得采样段前后端点的cdf值
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)     # [batch_size, n_samples, 2]
# 获得采样段前后端点的位置
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)   # [batch_size, n_samples, 2]

# 采样段cdf的值(俩个端点cdf的差值)
denom = (cdf_g[..., 1] - cdf_g[..., 0])
# 采样段值太小,则设为1
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)

# 计算占比=(精采样点cdf-前端点cdf)/采样段cdf值
t = (u - cdf_g[..., 0]) / denom
# 通过精采样点在所在采样段cdf长度中的占比,换算出对应精采样点在采样段上的位置
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

精采样点的过程如图所示。

通过计算精采样点的cdf与所处采样段cdf值的占比,换算出对应精采样点在采样段上的位置。采样段cdf值越大说明在该采样段的sdf变化越剧烈,适合插入的精采样点越多。

注意:u并不是代表真实的精采样点cdf值,而是为了方便精采样点更多的插入sdf变化越剧烈的采样段中,cdf值又是由sdf值得累加而来,因此真实的精采样点sdf值是需要在确定插入位置后由SDFNetwork计算出来。

cat_z_vals

batch_size, n_samples = z_vals.shape
_, n_importance = new_z_vals.shape
# 精采样点在世界坐标系下的位置
pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]

# 粗采样点和精采样点归并,记录采样点的顺序
z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
z_vals, index = torch.sort(z_vals, dim=-1)

if not last:
    # 精采样点的等势面
    new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)        # [batch_size, n_importance]
    # 粗采样点的等势面和精采样点的等势面拼接
    sdf = torch.cat([sdf, new_sdf], dim=-1)     # [batch_size, n_samples+n_importance]
    xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)     # [batch_size*(n_samples+n_importance)]
    index = index.reshape(-1)           # [batch_size*(n_samples+n_importance)]
    # 按照index对采样点sdf进行排序,是采样点位置和sdf值一一匹配
    sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)

确定了精采样点的插入位置后,由SDFNetwork来计算精采样点的sdf值。

小结

源码中精采样点的采样不是一次完成,而是经过多次完成的,这样可以让新一轮的精采样点在已有的采样点基础上更好地逼近物体表面,如下图所示。

在新一轮的up_sample中原本穿过球体的采样段被更加精确细微的采样段所替换。

总结

尽可能简单、详细的介绍NeuS训练阶段部分代码:sdf网络的结构和用途以及在粗采样点基础上对光线rays做更精细的采样。后续会讲解训练阶段的其他代码。

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值