MindSpore Application Case | 3D Reconstruction with Neural Radiance Fields (NeRF)

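In this example, we present a minimal implementation of the research paper NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis by Ben Mildenhall et al. The authors propose an ingenious way to synthesize novel views of a scene by modeling the volumetric scene function with a neural network.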

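To help you build an intuition for this, let's start with the following question: would it be possible to give the position of a pixel in an image to a neural network and ask the network to predict the color at that position?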

Figure 1: A neural network that takes image coordinates as input and is asked to predict the color at those coordinates.

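The neural network would, in effect, memorize (overfit on) the image. This means that our neural network would encode the entire image in its weights. We could then query the network at every position, and it would eventually reconstruct the entire image.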

Figure 2: The trained neural network recreates the image from scratch.
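As a minimal sketch of the idea (in NumPy rather than MindSpore, with a hypothetical helper `image_to_training_pairs` that is not part of this example's code), each pixel becomes one training pair: its normalized (row, column) coordinate is the input and its color is the target.

```python
import numpy as np


def image_to_training_pairs(img):
    """Turn an (H, W, C) image into (coordinate, color) pairs.

    Coordinates are normalized to [0, 1] so that a coordinate MLP sees
    inputs on a fixed scale regardless of the image size.
    """
    h, w = img.shape[:2]
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    coords = np.stack([rows / max(h - 1, 1), cols / max(w - 1, 1)], axis=-1)
    # One row per pixel: (H * W, 2) inputs and (H * W, C) targets
    return coords.reshape(-1, 2).astype(np.float32), img.reshape(h * w, -1)


# A tiny 2 x 3 RGB "image" with values in [0, 1]
img = np.arange(2 * 3 * 3, dtype=np.float32).reshape(2, 3, 3) / 17.0
coords, colors = image_to_training_pairs(img)
```

Fitting an MLP to these pairs and then querying it at every coordinate is the memorization described above.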

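This raises a question: how can we extend this idea to learn a 3D volumetric scene? Implementing a process similar to the one above would require knowledge of every voxel (volume pixel). As it turns out, this is a very challenging task.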

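The authors of the paper propose a minimal and elegant way to learn a 3D scene from just a few images of it. They forgo voxels for training entirely. Instead, the network learns to model the volumetric scene and can then generate novel views (images) of the 3D scene that the model was never shown during training.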

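A few prerequisites are needed to fully understand this process. We have structured the example so that you have all the required knowledge before starting the implementation.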

Setup

import os

os.environ["GLOG_v"] = "3"

import cv2
import json
import time
import numpy as np
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt

import mindspore
import mindspore as md
import mindspore.ops.operations as P
from mindspore import nn, ops, Tensor


# environment config
mode = "GRAPH_MODE"
seed = 1
device_id = 0
device = "GPU"
init = False

# data config
half_res = True
testskip = 8
white_bkgd = True
render_test = True

# network config
chunk = 32768
cap_n_samples = 64
cap_n_importance = 0
netchunk = 65536
lrate = 0.0005

# runner config
cap_n_iters = 3000
cap_n_rand = 1024
i_testset = 100
lrate_decay = 250


def context_setup(idx, device='GPU', mode=md.context.GRAPH_MODE):
    if init:
        return
    if device == "CPU":
        raise NotImplementedError("the `cumprod` op does not support CPU")
    md.context.set_context(mode=mode, device_target=device, device_id=idx)

md.set_seed(seed)
context_setup(device_id, device, getattr(md.context, mode))

Prerequisites

The data file contains images, camera poses, and focal lengths. The images are taken from multiple camera angles, as shown in Figure 3.

Figure 3: Multiple camera angles.

To understand camera poses in this context, we first have to think of a camera as a mapping between the real world and a 2-D image.

Figure 4: Mapping the 3-D world to a 2-D image through a camera.

Consider the following equation:

$$x = PX$$

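where x is a 2-D image point, X is a 3-D world point, and P is the camera matrix. P is a 3 x 4 matrix that plays a crucial role in mapping real-world objects onto the image plane.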
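To make the projection concrete, the following NumPy sketch builds P = K[R | t] from an assumed pinhole intrinsic matrix K and an identity pose, then projects one point in homogeneous coordinates. It assumes the common computer-vision convention in which the camera looks along +z; the Blender scenes used later look along -z, so treat the numbers as illustrative only.

```python
import numpy as np


def project(P, X):
    """Project a 3-D world point X with a 3x4 camera matrix P."""
    X_h = np.append(X, 1.0)   # homogeneous 4-vector
    x_h = P @ X_h             # homogeneous 3-vector
    return x_h[:2] / x_h[2]   # divide by depth to get pixel coordinates


# Assumed intrinsics: focal length f and principal point (cx, cy), in pixels
f, cx, cy = 100.0, 50.0, 40.0
K = np.array([[f, 0.0, cx], [0.0, f, cy], [0.0, 0.0, 1.0]])
Rt = np.hstack([np.eye(3), np.zeros((3, 1))])  # camera at the origin
P = K @ Rt                                     # the 3 x 4 camera matrix

pixel = project(P, np.array([0.5, -0.2, 2.0]))  # (75, 30)
```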

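The camera matrix is an affine transformation matrix that is concatenated with a 3 x 1 column [image height, image width, focal length] to produce the pose matrix. This matrix has dimensions 3 x 5, where the first 3 x 3 block is expressed in the camera's point of view. The axes are [down, right, backwards], or [-y, x, z], where the camera faces forward along -z.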

Figure 5: The affine transformation.
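The 3 x 5 pose layout can be sketched in NumPy as follows. `make_pose_matrix` and `split_pose_matrix` are hypothetical helpers, not part of this example's code: the Blender loader below reads 4 x 4 `transform_matrix` entries instead and computes the focal length separately from `camera_angle_x`.

```python
import numpy as np


def make_pose_matrix(c2w, height, width, focal):
    """Append the [height, width, focal] column to a 3x4 camera-to-world matrix."""
    hwf = np.array([[height], [width], [focal]], dtype=np.float64)
    return np.concatenate([c2w[:3, :4], hwf], axis=1)  # shape (3, 5)


def split_pose_matrix(pose):
    """Recover rotation, translation and intrinsics from a 3x5 pose matrix."""
    rotation = pose[:, :3]    # camera axes expressed in world coordinates
    translation = pose[:, 3]  # camera center in world coordinates
    h, w, f = pose[:, 4]
    return rotation, translation, (h, w, f)


# Identity rotation, camera 4 units along the world z axis
c2w = np.hstack([np.eye(3), np.array([[0.0], [0.0], [4.0]])])
pose = make_pose_matrix(c2w, 800, 800, 1111.1)
rotation, translation, (h, w, f) = split_pose_matrix(pose)
```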

The COLMAP frame is [right, down, forwards], or [x, -y, -z]. See the COLMAP documentation for more details.
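Switching between these conventions only permutes and negates columns of the rotation block. The sketch below assumes camera-to-world matrices whose first three columns are the camera axes; the function names are hypothetical. It converts COLMAP's [right, down, forwards] to [down, right, backwards], and that in turn to the OpenGL-style [right, up, backwards] frame, which is the frame `generate_rays` below assumes (its rays point along -z).

```python
import numpy as np


def colmap_to_llff(c2w):
    """[right, down, forwards] -> [down, right, backwards] (rotation columns)."""
    r = c2w[:, :3]
    # Translation (and any extra columns such as hwf) is kept unchanged
    return np.concatenate([r[:, 1:2], r[:, 0:1], -r[:, 2:3], c2w[:, 3:]], axis=1)


def llff_to_opengl(c2w):
    """[down, right, backwards] -> [right, up, backwards] (rotation columns)."""
    r = c2w[:, :3]
    return np.concatenate([r[:, 1:2], -r[:, 0:1], r[:, 2:3], c2w[:, 3:]], axis=1)


# A COLMAP-style pose whose camera axes coincide with the world axes
c2w = np.hstack([np.eye(3), np.array([[1.0], [2.0], [3.0]])])
llff = colmap_to_llff(c2w)
gl = llff_to_opengl(llff)
```

Composing the two conversions gives diag(1, -1, -1), i.e. [x, -y, -z], which is consistent with the COLMAP description above.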

Downloading the data

First, download the data yourself: NeRF Synthetics.

After downloading, arrange the data in the following structure:

datasets/
└── nerf_synthetics
    └── lego
        ├── test [600 entries exceeds filelimit, not opening dir]
        ├── train [100 entries exceeds filelimit, not opening dir]
        ├── transforms_test.json
        ├── transforms_train.json
        ├── transforms_val.json
        └── val [100 entries exceeds filelimit, not opening dir]
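Before running the loader, a quick check that the scene folder matches this layout can save a confusing error later. The helper below is hypothetical, not part of the original example, and only checks for the names shown in the tree above.

```python
from pathlib import Path

# Entries expected inside each scene folder, taken from the tree above
EXPECTED = ["train", "val", "test",
            "transforms_train.json", "transforms_val.json", "transforms_test.json"]


def missing_entries(scene_dir):
    """Return the expected files/folders that are absent from scene_dir."""
    scene_dir = Path(scene_dir)
    return [name for name in EXPECTED if not (scene_dir / name).exists()]


# An empty list means the layout looks right
print(missing_entries("datasets/nerf_synthetics/lego"))
```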

Data loading

datadir = Path("datasets/nerf_synthetics/lego")
train_imgs = list((datadir / "train").glob("*.png"))
num_images = len(train_imgs)

temp_fname = str(train_imgs[np.random.randint(low=0, high=num_images)])
tmp_img = cv2.imread(temp_fname, cv2.IMREAD_UNCHANGED)
tmp_img = cv2.cvtColor(tmp_img, cv2.COLOR_BGRA2RGBA)

plt.imshow(tmp_img)
plt.show()

Data pipeline

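Now that you understand the concept of the camera matrix and the mapping from a 3D scene to 2D images, let's talk about the inverse mapping, i.e., from a 2D image back to the 3D scene.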

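We need to discuss volumetric rendering with ray casting and ray tracing, which are common computer graphics techniques. This section will quickly get you up to speed on them.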

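Consider an image with N pixels. We shoot a ray through each pixel and sample some points along the ray. A ray is commonly parameterized by the equation $r(t) = o + td$, where t is the parameter, o is the origin, and d is the unit direction vector.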
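A small NumPy sketch of r(t) = o + td evaluates a ray at several values of t after normalizing d (the function name is hypothetical). Note that `VolumeRenderer` below keeps d unnormalized when placing samples, so there t measures depth along the camera's viewing axis rather than Euclidean distance, and `_composite` rescales the sample spacing by the norm of d.

```python
import numpy as np


def points_on_ray(origin, direction, t_values):
    """Evaluate r(t) = o + t * d for every t, with d normalized to unit length."""
    d = direction / np.linalg.norm(direction)
    # (T, 1) * (1, 3) broadcasts to one 3-D point per t value
    return origin[None, :] + t_values[:, None] * d[None, :]


o = np.array([0.0, 0.0, 4.0])
d = np.array([0.0, 0.0, -2.0])  # not unit length; normalized inside the function
pts = points_on_ray(o, d, np.array([2.0, 4.0, 6.0]))  # z-coordinates 2, 0, -2
```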

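We consider a single ray and sample some random points on it. Each of these sample points has a unique position (x, y, z), and the ray has a viewing angle (theta, phi). The viewing angle is particularly interesting, because a ray can pass through a single pixel in many different ways, each with a unique viewing angle. Another interesting thing to note here is the noise added to the sampling process: we add uniform noise to each sample so that the samples correspond to a continuous distribution. These sample points are fed as input to the NeRF model, which is then asked to predict the RGB color and volume density at each point.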

Figure 6: The data pipeline.
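The noisy sampling described above is usually implemented as stratified sampling: split [near, far] into equal bins and draw one uniform sample inside each bin. A minimal NumPy sketch with a hypothetical function name follows; `sample_along_rays` below uses a closely related variant whose bin edges are the midpoints between evenly spaced depths.

```python
import numpy as np


def stratified_samples(near, far, n_samples, rng):
    """Split [near, far] into n_samples equal bins and draw one uniform sample per bin."""
    edges = np.linspace(near, far, n_samples + 1)
    lower, upper = edges[:-1], edges[1:]
    # rng.random draws from [0, 1), so each sample stays inside its own bin
    return lower + (upper - lower) * rng.random(n_samples)


rng = np.random.default_rng(0)
t = stratified_samples(2.0, 6.0, 8, rng)  # 8 increasing depths in [2, 6)
```

Because every bin contributes exactly one sample, the samples are sorted and cover the whole interval, while the random offsets expose the network to a continuum of depths over many iterations.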

def trans_t(t):
    return np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, t], [0, 0, 0, 1]],
                    dtype=np.float32)


def rot_phi(phi):
    return np.array([[1, 0, 0, 0], [0, np.cos(phi), -np.sin(phi), 0],
                     [0, np.sin(phi), np.cos(phi), 0], [0, 0, 0, 1]],
                    dtype=np.float32)


def rot_theta(th):
    return np.array([[np.cos(th), 0, -np.sin(th), 0], [0, 1, 0, 0],
                     [np.sin(th), 0, np.cos(th), 0], [0, 0, 0, 1]],
                    dtype=np.float32)


def pose_spherical(theta, phi, radius):
    """pose_spherical"""
    c2w = trans_t(radius)
    c2w = np.matmul(rot_phi(phi / 180. * np.pi), c2w)
    c2w = np.matmul(rot_theta(theta / 180. * np.pi), c2w)
    c2w = np.matmul(
        np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]],
                 dtype=np.float32), c2w)
    return c2w


def load_blender_data(basedir, half_res=False, testskip=1):
    """load_blender_data"""
    splits = ['train', 'val', 'test']
    metas = {}
    for s in splits:
        with open(os.path.join(basedir, f'transforms_{s}.json'),
                  'r',
                  encoding='utf-8') as fp:
            metas[s] = json.load(fp)

    all_imgs = []
    all_poses = []
    counts = [0]
    for s in splits:
        meta = metas[s]
        imgs = []
        poses = []
        if s == 'train' or testskip == 0:
            skip = 1
        else:
            skip = testskip

        for frame in meta['frames'][::skip]:
            fname = os.path.join(basedir, frame['file_path'] + '.png')
            tmp_img = cv2.imread(fname, cv2.IMREAD_UNCHANGED)
            tmp_img = cv2.cvtColor(tmp_img, cv2.COLOR_BGRA2RGBA)
            imgs.append(tmp_img)
            poses.append(np.array(frame['transform_matrix']))
        imgs = (np.array(imgs) / 255.).astype(
            np.float32)  # keep all 4 channels (RGBA)
        poses = np.array(poses).astype(np.float32)
        counts.append(counts[-1] + imgs.shape[0])
        all_imgs.append(imgs)
        all_poses.append(poses)

    i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)]

    imgs = np.concatenate(all_imgs, 0)
    poses = np.concatenate(all_poses, 0)

    cap_h, cap_w = imgs[0].shape[:2]
    camera_angle_x = float(meta['camera_angle_x'])
    focal = .5 * cap_w / np.tan(.5 * camera_angle_x)

    render_poses = np.stack([
        pose_spherical(angle, -30.0, 4.0)
        for angle in np.linspace(-180, 180, 40 + 1)[:-1]
    ], axis=0)

    if half_res:
        # NOTE: despite the flag name, this downsamples by a factor of 8 (not 2)
        cap_h = cap_h // 8
        cap_w = cap_w // 8
        focal = focal / 8.

        imgs_half_res = np.zeros((imgs.shape[0], cap_h, cap_w, 4))
        for i, img in enumerate(imgs):
            # cv2.resize expects dsize as (width, height)
            imgs_half_res[i] = cv2.resize(img, (cap_w, cap_h),
                                          interpolation=cv2.INTER_AREA)
        imgs = imgs_half_res

    return md.Tensor(imgs).astype("float32"), md.Tensor(poses).astype(
        "float32"), md.Tensor(render_poses).astype("float32"), [
            cap_h, cap_w, focal
        ], i_split


images, poses, render_poses, hwf, i_split = load_blender_data(datadir, half_res, testskip)
print('Loaded blender', images.shape, render_poses.shape, hwf, datadir)
i_train, i_val, i_test = i_split

near = 2.
far = 6.

if white_bkgd:
    images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:])
else:
    images = images[..., :3]

if render_test:
    render_poses = poses[i_test.tolist()]

# Cast intrinsics to right types
cap_h, cap_w, focal = hwf
cap_h, cap_w = int(cap_h), int(cap_w)

hwf = [cap_h, cap_w, focal]

NeRF model

The model is a multilayer perceptron (MLP) with ReLU as its non-linearity.

An excerpt from the paper:

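“We encourage the representation to be multi-view consistent by restricting the network to predict the volume density sigma as a function of only the location x, while allowing the RGB color c to be predicted as a function of both location and viewing direction. To accomplish this, the MLP first processes the input 3D coordinate x with 8 fully-connected layers (using ReLU activations and 256 channels per layer), and outputs sigma and a 256-dimensional feature vector. This feature vector is then concatenated with the camera ray's viewing direction and passed to one additional fully-connected layer (using a ReLU activation and 128 channels) that outputs the view-dependent RGB color.”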

Figure 7: The NeRF model.
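The architecture in the excerpt can be traced with a shape-only NumPy sketch. The weights are random and recreated on every call, so nothing is learned; the point is to show where the skip connection sits and that sigma is computed before the viewing direction is mixed in. The input widths 63 and 27 assume the positional encodings produced by `create_nerf`'s defaults (multires=10, multires_views=4).

```python
import numpy as np

rng = np.random.default_rng(0)


def dense(x, n_out):
    """Fully connected layer with fresh random weights (shape illustration only)."""
    w = rng.normal(scale=0.05, size=(x.shape[-1], n_out))
    return x @ w


def relu(x):
    return np.maximum(x, 0.0)


def nerf_mlp_sketch(pos_enc, dir_enc, width=256, depth=8, skip=4):
    h = pos_enc
    for i in range(depth):
        h = relu(dense(h, width))
        if i == skip:
            # Skip connection: re-inject the encoded position after the 5th layer
            h = np.concatenate([pos_enc, h], axis=-1)
    sigma = dense(h, 1)            # density depends on position only
    feature = dense(h, width)      # 256-dimensional feature vector
    h = relu(dense(np.concatenate([feature, dir_enc], axis=-1), width // 2))
    rgb = dense(h, 3)              # color also depends on the viewing direction
    return rgb, sigma


pos = rng.normal(size=(4, 63))   # encoded positions
dirs = rng.normal(size=(4, 27))  # encoded viewing directions
rgb, sigma = nerf_mlp_sketch(pos, dirs)
```

As in this sketch, `NeRFMLP` returns raw color and density values; they are only passed through sigmoid and ReLU later, in `VolumeRenderer._transfer`.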

class NeRFMLP(nn.Cell):
    """
    NeRF MLP architecture.

    Args:
        cap_d (int, optional): Model depth. Default: 8.
        cap_w (int, optional): Model width. Default: 256.
        input_ch (int, optional): Input channel. Default: 3.
        input_ch_views (int, optional): Input view channel. Default: 3.
        output_ch (int, optional): Output channel. Default: 4.
        skips (tuple, optional): Skip connection layer index. Default: (4,).
        use_view_dirs (bool, optional): Use view directions or not. Default: False.

    Inputs:
        - **x** (Tensor) - Query tensors. points and view directions (..., 6).

    Outputs:
        Tensor, query features (..., feature_dims).

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> model = NeRFMLP()
        >>> inputs = ms.numpy.randn(1, 3)
        >>> outputs = model(inputs)
        [[0.2384 0.8456 0.6845 0.1584]]
    """

    def __init__(
            self,
            cap_d=8,
            cap_w=256,
            input_ch=3,
            input_ch_views=3,
            output_ch=4,
            skips=(4,),
            use_view_dirs=False,
    ):
        super().__init__()
        self.cap_d = cap_d
        self.cap_w = cap_w
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.skips = skips
        self.use_view_dirs = use_view_dirs

        self.pts_linears = nn.CellList([nn.Dense(in_channels=input_ch, out_channels=cap_w)] + [
            nn.Dense(in_channels=cap_w, out_channels=cap_w) if i not in
            self.skips else nn.Dense(in_channels=cap_w + input_ch, out_channels=cap_w) for i in range(cap_d - 1)
        ])

        self.views_linears = nn.CellList([nn.Dense(in_channels=input_ch_views + cap_w, out_channels=cap_w // 2)])

        if use_view_dirs:
            self.feature_linear = nn.Dense(in_channels=cap_w, out_channels=cap_w)
            self.alpha_linear = nn.Dense(in_channels=cap_w, out_channels=1)
            self.rgb_linear = nn.Dense(in_channels=cap_w // 2, out_channels=3)
        else:
            self.output_linear = nn.Dense(in_channels=cap_w, out_channels=output_ch)

    def construct(self, x):
        """NeRF MLP construct"""
        input_pts, input_views = x[..., :self.input_ch], x[..., self.input_ch:]
        h = input_pts
        for i, _ in enumerate(self.pts_linears):
            h = self.pts_linears[i](h)
            h = P.ReLU()(h)
            if i in self.skips:
                h = P.Concat(-1)([input_pts, h])

        if self.use_view_dirs:
            alpha = self.alpha_linear(h)
            feature = self.feature_linear(h)
            h = P.Concat(-1)([feature, input_views])

            for i, _ in enumerate(self.views_linears):
                h = self.views_linears[i](h)
                h = P.ReLU()(h)

            rgb = self.rgb_linear(h)
            outputs = P.Concat(-1)([rgb, alpha])
        else:
            outputs = self.output_linear(h)

        return outputs

Differentiable renderer
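`VolumeRenderer._composite` implements the discrete volume-rendering sum from the NeRF paper: alpha_i = 1 - exp(-sigma_i * delta_i), transmittance T_i = prod_{j<i} (1 - alpha_j), weight w_i = T_i * alpha_i, and pixel color C = sum_i w_i * c_i. The single-ray NumPy sketch below (hypothetical function name) shows the same computation without batching. It assumes colors already in [0, 1] and non-negative densities (the MindSpore code gets these from sigmoid and ReLU in `_transfer`), and it omits the 1e-10 stabilizer used inside the cumulative product.

```python
import numpy as np


def composite(rgb, sigma, t_vals, white_bkgd=False):
    """Alpha-composite samples along a single ray.

    rgb: (N, 3) colors; sigma: (N,) densities; t_vals: (N,) sorted sample depths.
    """
    deltas = np.append(np.diff(t_vals), 1e10)  # last interval treated as infinite
    alpha = 1.0 - np.exp(-sigma * deltas)
    # Transmittance T_i = prod_{j<i} (1 - alpha_j), with T_0 = 1
    trans = np.cumprod(np.concatenate([[1.0], 1.0 - alpha[:-1]]))
    weights = alpha * trans
    color = (weights[:, None] * rgb).sum(axis=0)
    if white_bkgd:
        # Light not absorbed by any sample comes from a white background
        color = color + (1.0 - weights.sum())
    return color, weights


rgb = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
t = np.array([2.0, 3.0, 4.0])
# The first sample is effectively opaque, so the ray returns its color (red)
color, weights = composite(rgb, np.array([1e6, 0.0, 0.0]), t)
```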

class VolumeRenderer(nn.Cell):
    """
    Volume Renderer architecture.

    Args:
        chunk (int): Number of rays processed in parallel, decrease if running out of memory.
        cap_n_samples (int): Number of coarse samples per ray for coarse net.
        cap_n_importance (int): Number of additional fine samples per ray for fine net.
        net_chunk (int): Number of pts sent through network in parallel, decrease if running out of memory.
        white_bkgd (bool): Set to render synthetic data on a white background (always use for DeepVoxels).
        model_coarse (nn.Cell): Coarse net.
        model_fine (nn.Cell, optional): Fine net, or None.
        embedder_p (Dict): Config for positional encoding for point.
        embedder_d (Dict): Config for positional encoding for view direction.
        near (float, optional): The near plane. Default: 0.0.
        far (float, optional): The far plane. Default: 1e6.

    Inputs:
        - **rays** (Tensor) - The ray tensor. (..., num_pts_per_ray, ray_batch_dims).

    Outputs:
        Tuple of 2 Tensor, the output tensors.

        - **fine_net_output** (Tensor, optional) - The fine net output features.
        - **coarse_net_output** (Tensor) - The coarse net output features.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> model = VolumeRenderer(1000, 6, 12, 1000, False, P.Identity(), P.Identity(), P.Identity(), P.Identity())
        >>> inputs = ms.numpy.randn(1, 1, 3)
        >>> outputs = model(inputs)
        ([[0.2384 0.8456 0.1273]], [[0.8653 0.1866 0.6382]])
    """

    def __init__(self,
                 chunk,
                 cap_n_samples,
                 cap_n_importance,
                 net_chunk,
                 white_bkgd,
                 model_coarse,
                 model_fine,
                 embedder_p,
                 embedder_d,
                 near=0.0,
                 far=1e6):
        super().__init__()

        self.chunk = chunk
        self.cap_n_samples = cap_n_samples
        self.cap_n_importance = cap_n_importance
        self.net_chunk = net_chunk
        self.white_bkgd = white_bkgd

        self.model_coarse = model_coarse
        self.model_fine = model_fine
        # embedder for positions
        self.embedder_p = Embedder(**embedder_p)
        # embedder for view-in directions
        self.embedder_d = Embedder(**embedder_d)

        self.near = near
        self.far = far

    def construct(self, rays):
        """Volume renderer construct."""
        return self.inference(rays)

    def inference(self, rays):
        """Volume renderer inference."""
        # round the loop bound up to a multiple of the chunk size (empty trailing chunks are skipped)
        cap_n_rays = (rays.shape[1] // self.chunk + 1) * self.chunk
        cap_n = self.cap_n_samples

        res_ls = {"rgb_map_coarse": [], "rgb_map_fine": []}

        for i in range(0, cap_n_rays, self.chunk):
            ray_origins, ray_dirs = rays[:, i:i + self.chunk, :]
            reshape_op = mindspore.ops.Reshape()
            view_dirs = reshape_op(
                ray_dirs / mindspore.numpy.norm(ray_dirs, axis=-1, keepdims=True),
                (-1, 3),
            )

            near, far = self.near * mindspore.numpy.ones_like(ray_dirs[..., :1]), self.far * mindspore.numpy.ones_like(
                ray_dirs[..., :1])
            cap_m = ray_origins.shape[0]
            if cap_m == 0:
                continue

            # stratified sampling along rays
            s_samples = sample_along_rays(near, far, cap_n)

            # position samples along rays
            unsqueeze_op = P.ExpandDims()
            pos_samples = unsqueeze_op(ray_origins,
                                       1) + unsqueeze_op(ray_dirs, 1) * unsqueeze_op(s_samples, 2)
            # expand ray directions to the same shape of samples
            expand_op = P.BroadcastTo(pos_samples.shape)
            dir_samples = expand_op(unsqueeze_op(view_dirs, 1))

            reshape_op = P.Reshape()
            pos_samples = reshape_op(pos_samples, (-1, 3))
            dir_samples = reshape_op(dir_samples, (-1, 3))

            # retrieve optic data from the network
            optic_d = self._run_network_model_coarse(pos_samples, dir_samples)
            optic_d = mindspore.numpy.reshape(optic_d, [cap_m, cap_n, 4])

            # composite optic data to generate a RGB image
            rgb_map_coarse, weights_coarse = self._composite(optic_d, s_samples, ray_dirs)

            if self.cap_n_importance > 0:
                z_vals_mid = 0.5 * (s_samples[..., 1:] + s_samples[..., :-1])
                z_samples = sample_pdf(z_vals_mid, weights_coarse[..., 1:-1], self.cap_n_importance)
                z_samples = mindspore.ops.stop_gradient(z_samples)

                sort_op = P.Sort(axis=-1)
                z_vals, _ = sort_op(P.Concat(-1)([s_samples, z_samples]))
                pts = (ray_origins[..., None, :] + ray_dirs[..., None, :] * z_vals[..., :, None]
                      )

                expand_op_2 = P.BroadcastTo(pts.shape)
                dir_samples = expand_op_2(unsqueeze_op(view_dirs, 1))

                pts = reshape_op(pts, (-1, 3))
                dir_samples = reshape_op(dir_samples, (-1, 3))

                optic_d = self._run_network_model_fine(pts, dir_samples)
                optic_d = reshape_op(optic_d, (cap_m, cap_n + self.cap_n_importance, 4))

                rgb_map_fine, _ = self._composite(optic_d, z_vals, ray_dirs)
            else:
                rgb_map_fine = rgb_map_coarse

            res_ls["rgb_map_coarse"].append(rgb_map_coarse)
            res_ls["rgb_map_fine"].append(rgb_map_fine)

        res = {}
        for k, v in res_ls.items():
            res[k] = P.Concat(0)(v)

        return res["rgb_map_fine"], res["rgb_map_coarse"]

    def _run_network_model_fine(self, pts, view_dirs):
        """Run fine model."""
        inputs_flat = pts
        embedded = self.embedder_p(inputs_flat)

        if view_dirs is not None:
            input_dirs_flat = view_dirs
            embedded_dirs = self.embedder_d(input_dirs_flat)
            embedded = P.Concat(-1)([embedded, embedded_dirs])

        chunk = self.net_chunk
        outputs_flat_ls = []
        for i in range(0, embedded.shape[0], chunk):
            outputs_flat_ls.append(self.model_fine(embedded[i:i + chunk]))
        outputs_flat = P.Concat(0)(outputs_flat_ls)
        return outputs_flat

    def _run_network_model_coarse(self, pts, view_dirs):
        """Run coarse model."""
        inputs_flat = pts
        embedded = self.embedder_p(inputs_flat)

        if view_dirs is not None:
            input_dirs_flat = view_dirs
            embedded_dirs = self.embedder_d(input_dirs_flat)
            embedded = P.Concat(-1)([embedded, embedded_dirs])

        chunk = self.net_chunk
        outputs_flat_ls = []
        for i in range(0, embedded.shape[0], chunk):
            outputs_flat_ls.append(self.model_coarse(embedded[i:i + chunk]))
        outputs_flat = P.Concat(0)(outputs_flat_ls)
        return outputs_flat

    def _transfer(self, optic_d, dists):
        """Transfer occupancy to alpha values."""
        sigmoid = P.Sigmoid()
        rgbs = sigmoid(optic_d[..., :3])
        alphas = 1.0 - P.Exp()(-1.0 * (P.ReLU()(optic_d[(..., 3)])) * dists)

        return rgbs, alphas

    def _composite(self, optic_d, s_samples, rays_d):
        """Composite the colors and densities."""
        # distances between each samples
        dists = s_samples[..., 1:] - s_samples[..., :-1]
        dists_list = (
            dists,
            (mindspore.numpy.ones([]) * 1e10).expand_as(dists[..., :1]),
        )
        dists = P.Concat(-1)(dists_list)

        dists = dists * mindspore.numpy.norm(rays_d[..., None, :], axis=-1)

        # retrieve display colors and alphas for each samples by a transfer function
        rgbs, alphas = self._transfer(optic_d, dists)

        weights = alphas * mindspore.numpy.cumprod(
            P.Concat(-1)([mindspore.numpy.ones((alphas.shape[0], 1)), 1.0 - alphas + 1e-10])[:, :-1],
            axis=-1,
        )
        sum_op = mindspore.ops.ReduceSum()
        rgb_map = sum_op(weights[..., None] * rgbs, -2)
        acc_map = sum_op(weights, -1)

        if self.white_bkgd:
            rgb_map = rgb_map + (1.0 - acc_map[..., None])

        return rgb_map, weights

class Embedder(nn.Cell):
    """
    Embedder for positional embedding.

    Args:
        input_dims (int): Input dimensions.
        max_freq_pow (float): Maximum frequency pow.
        num_freqs (int): Number of frequencies.
        periodic_fns (list, optional): Periodic fns. Default: [mindspore.ops.Sin(), mindspore.ops.Cos()].
        log_sampling (bool, optional): Log sampling. Default: True.
        include_input (bool, optional): Include input or not. Default: True.

    Inputs:
        inputs (Tensor) - Input tensor.

    Outputs:
        Tensor, input concatenated with positional embeddings.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> model = Embedder(1, 1)
        >>> inputs = ms.numpy.randn(1)
        >>> outputs = model(inputs)
        [0.1384 0.4426]
    """

    def __init__(
            self,
            input_dims,
            max_freq_pow,
            num_freqs,
            periodic_fns=(mindspore.ops.Sin(), mindspore.ops.Cos()),
            log_sampling=True,
            include_input=True,
    ):
        super().__init__()

        embed_fns = []
        out_dims = 0
        if include_input:
            embed_fns.append(mindspore.ops.Identity())
            out_dims += input_dims

        if log_sampling:
            freq_bands = mindspore.Tensor(2.0)**mindspore.numpy.linspace(0.0, max_freq_pow, num=num_freqs)
        else:
            freq_bands = mindspore.numpy.linspace(2.0**0.0, 2.0**max_freq_pow, num=num_freqs)

        for _ in freq_bands:
            for p_fn in periodic_fns:
                embed_fns.append(p_fn)
                out_dims += input_dims

        self.embed_fns = embed_fns
        self.out_dims = out_dims

        self.freq_bands = freq_bands

    def construct(self, inputs):
        """Embedder construct."""
        out = []
        for i, fn in enumerate(self.embed_fns):
            if i == 0:
                out.append(fn(inputs))
            else:
                out.append(fn(inputs * self.freq_bands[(i - 1) // 2]))
        return P.Concat(-1)(out)

Defining the model
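`get_embedder` configures the positional encoding gamma(p), which maps each coordinate to sines and cosines at log-spaced frequencies 2^0, ..., 2^(L-1) and prepends the raw input, in the same order as `Embedder.construct`. A NumPy sketch of that mapping (hypothetical function name) makes the output sizes easy to check: with 3-D inputs, L = 10 gives 3 + 3 * 2 * 10 = 63 features and L = 4 gives 27, which are the input widths `create_nerf` passes to `NeRFMLP` by default.

```python
import numpy as np


def positional_encoding(p, num_freqs, include_input=True):
    """gamma(p) = [p, sin(2^0 p), cos(2^0 p), ..., sin(2^(L-1) p), cos(2^(L-1) p)]."""
    freqs = 2.0 ** np.arange(num_freqs)  # log-spaced: 1, 2, 4, ...
    parts = [p] if include_input else []
    for f in freqs:
        parts += [np.sin(f * p), np.cos(f * p)]
    return np.concatenate(parts, axis=-1)


x = np.zeros((5, 3))
pos_width = positional_encoding(x, 10).shape[-1]  # 63
dir_width = positional_encoding(x, 4).shape[-1]   # 27
```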

def get_embedder(multi_res, i=0):
    """
    Get embedder function.

    Args:
        multi_res (int): Log2 of max freq for positional encoding.
        i (int, optional): Set 0 for default positional encoding, -1 for none. Default: 0.

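    Returns:
        Tuple of dict and int, the keyword arguments for building an `Embedder`
        and the output dimensions. (With i == -1, an identity op is returned
        instead of a dict, which `VolumeRenderer` cannot consume.)

        - **embed_kwargs** (dict) - Keyword arguments for `Embedder`.
        - **out_dims** (int) - The output dimensions.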
    """
    if i == -1:
        return md.ops.Identity(), 3

    embed_kwargs = {
        "include_input": True,
        "input_dims": 3,
        "max_freq_pow": multi_res - 1,
        "num_freqs": multi_res,
        "log_sampling": True,
        "periodic_fns": [md.ops.Sin(), md.ops.Cos()],
    }

    embedder_obj = Embedder(**embed_kwargs)
    embed = embed_kwargs
    return embed, embedder_obj.out_dims


def create_nerf(multires=10,
                i_embed=0,
                multires_views=4,
                netdepth=8,
                netwidth=256,
                netdepth_fine=8,
                netwidth_fine=256,
                use_view_dirs=True,
                cap_n_importance=0,
                ckpt_path=None):
    """create nerf model and load weights"""
    embed_fn, input_ch = get_embedder(multires, i_embed)

    input_ch_views = 0
    embeddirs_fn = None
    if use_view_dirs:
        embeddirs_fn, input_ch_views = get_embedder(multires_views,
                                                    i_embed)
    # Create networks
    output_ch = 4
    skips = [4]
    model_coarse = NeRFMLP(cap_d=netdepth,
                           cap_w=netwidth,
                           input_ch=input_ch,
                           output_ch=output_ch,
                           skips=skips,
                           input_ch_views=input_ch_views,
                           use_view_dirs=use_view_dirs)
    grad_vars = [{"params": model_coarse.trainable_params()}]

    model_fine = None
    if cap_n_importance > 0:
        model_fine = NeRFMLP(cap_d=netdepth_fine,
                             cap_w=netwidth_fine,
                             input_ch=input_ch,
                             output_ch=output_ch,
                             skips=skips,
                             input_ch_views=input_ch_views,
                             use_view_dirs=use_view_dirs)
        grad_vars += [{"params": model_fine.trainable_params()}]

    optimizer = None
    start_iter = 0

    # Load checkpoints
    if ckpt_path is not None:
        print("Reloading from", ckpt_path)
        ckpt = md.load_checkpoint(ckpt_path)

        # Load training steps
        start_iter = int(ckpt["global_steps"]) + 1

        # Load network weights
        md.load_param_into_net(
            model_coarse,
            {key: value for key, value in ckpt.items() if ".model_coarse." in key},
        )
        if model_fine is not None:
            md.load_param_into_net(
                model_fine,
                {key: value for key, value in ckpt.items() if ".model_fine." in key},
            )
    else:
        print("No ckpt reloaded")

    return start_iter, optimizer, model_coarse, model_fine, embed_fn, embeddirs_fn

class RendererWithCriterion(nn.Cell):
    """
    Renderer with criterion.

    Args:
        renderer (nn.Cell): Renderer.
        loss_fn (nn.Cell, optional): Loss function. Default: nn.MSELoss().

    Inputs:
        - **rays** (Tensor) - Rays tensor.
        - **gt** (Tensor) - Ground truth tensor.

    Outputs:
        Tensor, loss for one forward pass.
    """
    def __init__(self, renderer, loss_fn=nn.MSELoss()):
        """Renderer with criterion."""
        super().__init__()
        self.renderer = renderer
        self.loss_fn = loss_fn

    def construct(self, rays, gt):
        """Renderer Trainer construct."""
        rgb_map_fine, rgb_map_coarse = self.renderer(rays)
        return self.loss_fn(rgb_map_fine, gt) + self.loss_fn(rgb_map_coarse, gt)


# Create nerf model
start_iter, optimizer, model_coarse, model_fine, embed_fn, embeddirs_fn = create_nerf()
# Training steps
global_steps = start_iter
# Create volume renderer
renderer = VolumeRenderer(chunk, cap_n_samples,
                          cap_n_importance, netchunk,
                          white_bkgd, model_coarse, model_fine,
                          embed_fn, embeddirs_fn, near, far)

renderer.model_coarse.to_float(md.dtype.float16)
if renderer.model_fine is not None:
    renderer.model_fine.to_float(md.dtype.float16)

renderer_with_criteron = RendererWithCriterion(renderer)
optimizer = nn.Adam(params=renderer.trainable_params(),
                    learning_rate=lrate,
                    beta1=0.9,
                    beta2=0.999)

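# Static loss scaling: gradients are computed with a scaled sensitivity (sens)
# and multiplied by 1 / sens before the optimizer step.
grad_scale = ops.MultitypeFuncGraph("grad_scale")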
@grad_scale.register("Tensor", "Tensor")
def gradient_scale(scale, grad):
    return grad * ops.cast(scale, ops.dtype(grad))

class CustomTrainOneStepCell(nn.TrainOneStepCell):
    def __init__(self, network, optimizer, sens=1.0):
        super(CustomTrainOneStepCell, self).__init__(network, optimizer, sens)
        self.hyper_map = ops.HyperMap()
        self.reciprocal_sense = Tensor(1 / sens, mindspore.float32)

    def scale_grad(self, gradients):
        gradients = self.hyper_map(ops.partial(grad_scale, self.reciprocal_sense), gradients)
        return gradients

    def construct(self, *inputs):
        loss = self.network(*inputs)
        sens = ops.fill(loss.dtype, loss.shape, self.sens)
        # calculate gradients, the sens will equal to the loss_scale
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        # gradients / loss_scale
        grads = self.scale_grad(grads)
        # reduce gradients in distributed scenarios
        grads = self.grad_reducer(grads)
        loss = ops.depend(loss, self.optimizer(grads))
        return loss

loss_scale = 1024.0
train_renderer = CustomTrainOneStepCell(renderer_with_criteron, optimizer, loss_scale)
train_renderer.set_train()

Training
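Two small formulas drive the training code below: `train_net` decays the learning rate exponentially, lr = lrate * 0.1^(step / (lrate_decay * 1000)), and `psnr_from_mse` converts an MSE into PSNR as -10 * log10(MSE). A plain-Python sketch with hypothetical function names shows that, with the configured lrate_decay = 250, the learning rate falls by only about 3% over the 3000 training iterations.

```python
import math


def decayed_lr(base_lr, step, decay_k_steps, decay_rate=0.1):
    """Exponential decay: base_lr * decay_rate ** (step / (decay_k_steps * 1000))."""
    return base_lr * decay_rate ** (step / (decay_k_steps * 1000))


def psnr(mse_value):
    """PSNR in dB for images with values in [0, 1]."""
    return -10.0 * math.log10(mse_value)


print(decayed_lr(5e-4, 3000, 250))  # about 4.86e-4
print(psnr(0.001))                  # about 30 dB
```

One consequence worth keeping in mind: `train_net` passes the combined coarse + fine loss to `psnr_from_mse`. With cap_n_importance = 0 the fine map equals the coarse map, so that loss is twice the per-pixel MSE and the reported training PSNR comes out roughly 10 * log10(2), about 3 dB, lower than the image PSNR.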

def generate_rays(h, w, f, pose):
    '''
    Given an image plane, generate rays from the camera origin to each pixel on the image plane.

    Arguments:
        h: height of the image plane.
        w: width of the image plane.
        f: focal length of the image plane.
        pose: the extrinsic parameters of the camera. (3, 4) or (4, 4)

    Returns:
        A tuple: origins of rays, directions of rays
    '''

    # Coordinates of the 2D grid
    cols = md.ops.ExpandDims()(
        md.numpy.linspace(-1.0 * w / 2, w - 1 - w / 2, w) / f,
        0).repeat(h, axis=0)  # (h, w)
    rows = md.ops.ExpandDims()(
        -1.0 * md.numpy.linspace(-1.0 * h / 2, h - 1 - h / 2, h) / f,
        1).repeat(w, axis=1)  # (h, w)

    # Ray directions for all pixels
    ray_dirs = md.numpy.stack([cols, rows, -1.0 * md.numpy.ones_like(cols)],
                              axis=-1)  # (h, w, 3)
    # Apply rotation transformation to make each ray orient according to the camera
    unsqueeze_op = md.ops.ExpandDims()
    ray_dirs = md.numpy.sum(unsqueeze_op(ray_dirs, 2) * pose[:3, :3], axis=-1)
    # Origin position
    rays_oris = pose[:3, -1].expand_as(ray_dirs)  # (h, w, 3)

    return rays_oris, ray_dirs.astype(pose.dtype)  # (h, w, 3), (h, w, 3)


def train_net(iter_, train_renderer, optimizer, rays, gt):
    '''
    Train the network for one step.

    Arguments:
        iter_: current iteration.
        train_renderer: the training cell wrapping the volume renderer and the loss.
        optimizer: the network optimizer.
        rays: a batch of rays for training, (2, #rays, 3) (origins and directions).
        gt: the ground-truth colors, (#rays, 3).

    Returns:
        A tuple: (training loss, PSNR computed from that loss).
    '''
    loss = train_renderer(rays, gt)

    # Update learning rate (exponential decay)
    decay_rate = 0.1
    decay_steps = lrate_decay * 1000
    new_lrate = lrate * (decay_rate**(iter_ / decay_steps))
    # Write into the existing Parameter so the compiled optimizer sees the new value
    optimizer.learning_rate.set_data(md.Tensor(new_lrate, md.float32))

    return float(loss), float(psnr_from_mse(loss))


def test_net(img_h,
             img_w,
             focal,
             renderer,
             test_poses,
             gt=None,
             on_progress=None,
             on_complete=None):
    '''
    Test the network and generate results.

    Arguments:
        img_h: height of image plane.
        img_w: width of image plane.
        focal: focal length.
        renderer: the volume renderer.
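        test_poses: poses used to test the network. (#poses, 4, 4)
        gt: optional ground-truth images for computing MSE and PSNR. (#poses, img_h, img_w, 3)
        on_progress: a callback invoked per generated result, or a list of two
            callbacks (one for the result, one for the ground truth).
        on_complete: a callback function invoked after generating all results.

    Returns:
        A tuple: (mean test time, mean MSE loss, mean PSNR, list of 8-bit images).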
    '''

    rgb_maps = []
    loss_ls = []
    psnr_ls = []
    time_ls = []

    reshape_op = md.ops.Reshape()
    stack_op = md.ops.Stack(axis=0)

    image_list = []

    for j, test_pose in enumerate(test_poses):
        t0 = time.time()

        # Generate rays for all pixels
        ray_oris, ray_dirs = generate_rays(img_h, img_w, focal, test_pose)
        ray_oris = reshape_op(ray_oris, (-1, 3))
        ray_dirs = reshape_op(ray_dirs, (-1, 3))
        test_batch_rays = stack_op([ray_oris, ray_dirs])

        # Retrieve testing results
        rgb_map, _ = renderer.inference(test_batch_rays)
        rgb_map = reshape_op(rgb_map, (img_h, img_w, 3))
        rgb_maps.append(rgb_map.asnumpy())

        # If given groundtruth, compute MSE and PSNR
        if gt is not None:
            loss = mse(rgb_map, gt[j])
            psnr = psnr_from_mse(loss)
            loss_ls.append(float(loss))
            psnr_ls.append(float(psnr))

        time_ls.append(time.time() - t0)

        # Handle each testing result
        if on_progress:
            if isinstance(on_progress, list):
                on_progress[0](j, rgb_maps[-1])
                if gt is not None:
                    on_progress[1](j, gt[j].asnumpy())
            else:
                on_progress(j, rgb_maps[-1])

        image_list.append(to8b(rgb_maps[-1]))

    # Handle all testing results
    if on_complete:
        on_complete(np.stack(rgb_maps, 0))

    if not loss_ls:
        loss_ls = [0.0]
    if not psnr_ls:
        psnr_ls = [0.0]
    if not time_ls:
        time_ls = [0.0]

    return np.mean(time_ls), np.mean(loss_ls), np.mean(psnr_ls), image_list


def to8b(x):
    """Convert normalized color to 8-bit color"""
    return (255 * np.clip(x, 0.0, 1.0)).astype(np.uint8)


def mse(im1, im2):
    '''
    MSE between two images.
    '''

    return md.numpy.mean((im1 - im2)**2)


psnr_from_mse_base = md.Tensor([10.0])


def psnr_from_mse(v):
    '''
    Convert MSE to PSNR.
    '''
    return -10.0 * (md.numpy.log(v) / md.numpy.log(psnr_from_mse_base))


def sample_grid_2d(cap_h, cap_w, cap_n):
    """
    Sample cells in an cap_h x cap_w mesh grid.

    Args:
        cap_h (int): Height of the mesh grid.
        cap_w (int): Width of the mesh grid.
        cap_n (int): The number of samples.

    Returns:
        Tuple of 2 Tensor, sampled rows and sampled columns.

        - **select_coords_x** (Tensor) - Sampled rows.
        - **select_coords_y** (Tensor) - Sampled columns.
    """
    if cap_n > cap_w * cap_h:
        cap_n = cap_w * cap_h

    # Create a 2D mesh grid where each element is the coordinate of the cell
    stack_op = md.ops.Stack(-1)
    coords = stack_op(
        md.numpy.meshgrid(
            md.numpy.linspace(0, cap_h - 1, cap_h),
            md.numpy.linspace(0, cap_w - 1, cap_w),
            indexing="ij",
        ))
    # Flatten the mesh grid
    coords = md.ops.Reshape()(coords, (-1, 2))
    # Sample N distinct cells in the mesh grid
    select_indexes = np.random.choice(coords.shape[0],
                                      size=[cap_n],
                                      replace=False)
    # Gather the coordinates of the selected cells
    select_coords = coords[select_indexes.tolist()].astype("int32")

    return select_coords[:, 0], select_coords[:, 1]
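
`sample_grid_2d` first builds the full cap_h x cap_w coordinate grid and then picks cap_n rows of it. The minimal NumPy sketch below does the same job without building the grid. It draws flat indices directly and converts them to (row, col) with `divmod`, which matches the row-major ("ij") flattening used above. Two assumptions apply: it uses a `numpy.random.Generator` instead of the global `np.random` state, and `sample_grid_2d_np` is a hypothetical name.

```python
import numpy as np


def sample_grid_2d_np(h, w, n, rng=None):
    """Draw n distinct pixel coordinates (row, col) from an h x w grid.

    NumPy-only sketch of the idea behind sample_grid_2d; n is clipped to h * w.
    """
    rng = np.random.default_rng() if rng is None else rng
    n = min(n, h * w)
    # Pick n distinct flat indices, then convert them back to (row, col)
    flat = rng.choice(h * w, size=n, replace=False)
    rows, cols = np.divmod(flat, w)  # row = flat // w, col = flat % w
    return rows, cols


rows, cols = sample_grid_2d_np(4, 5, 6, rng=np.random.default_rng(0))
print(list(zip(rows.tolist(), cols.tolist())))  # six distinct (row, col) pairs
```

Both versions draw the same distribution of pixels. The sketch only skips allocating the (cap_h * cap_w, 2) coordinate tensor.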


def sample_along_rays(near,
                      far,
                      cap_cap_n_samples,
                      lin_disp=False,
                      perturb=True):
    """
    Sample points along rays.

    Args:
        near (Tensor): The near bound (closest depth) of each ray. (cap_n_rays).
        far (Tensor): The far bound (furthest depth) of each ray. (cap_n_rays).
        cap_cap_n_samples (int): The number of sampled points for each ray.
        lin_disp (bool): True to sample linearly in inverse depth (disparity) rather than in depth (used for some datasets).
        perturb (bool): True for stratified sampling, False for uniform sampling.

    Returns:
        Tensor, sampled depths, where the j-th element of the i-th row is the j-th sampled position along the i-th ray.
    """
    # The number of rays
    cap_n_rays = near.shape[0]

    # Uniform samples along rays
    t_vals = md.numpy.linspace(0.0, 1.0, num=cap_cap_n_samples)
    if not lin_disp:
        z_vals = near * (1.0 - t_vals) + far * t_vals
    else:
        z_vals = 1.0 / (1.0 / near * (1.0 - t_vals) + 1.0 / far * t_vals)

    expand_op = md.ops.BroadcastTo((cap_n_rays, cap_cap_n_samples))
    z_vals = expand_op(z_vals)

    if perturb:
        # Get intervals between samples
        mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
        cat_op = md.ops.Concat(-1)
        upper = cat_op([mids, z_vals[..., -1:]])
        lower = cat_op([z_vals[..., :1], mids])
        # Stratified samples in those intervals
        t_rand = md.numpy.rand(z_vals.shape)
        z_vals = lower + (upper - lower) * t_rand

    return z_vals
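
Stratified sampling splits [near, far] into cap_cap_n_samples bins and draws one uniform sample inside each bin. Every part of the ray is therefore covered, while the exact positions change from iteration to iteration. The NumPy sketch below mirrors the `perturb=True` branch of `sample_along_rays` and omits the `lin_disp` option. It assumes `near` and `far` have shape (n_rays, 1) so they broadcast against the sample positions. `stratified_depths` is a hypothetical name.

```python
import numpy as np


def stratified_depths(near, far, n_samples, perturb=True, rng=None):
    """Sample n_samples depths per ray between near and far (NumPy sketch).

    near, far: arrays of shape (n_rays, 1). Returns shape (n_rays, n_samples).
    """
    rng = np.random.default_rng() if rng is None else rng
    t_vals = np.linspace(0.0, 1.0, n_samples)
    z_vals = near * (1.0 - t_vals) + far * t_vals  # evenly spaced in depth
    if perturb:
        # Bin edges are the midpoints between neighbouring samples
        mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = np.concatenate([mids, z_vals[..., -1:]], axis=-1)
        lower = np.concatenate([z_vals[..., :1], mids], axis=-1)
        # Jitter each sample uniformly inside its own bin [lower, upper)
        z_vals = lower + (upper - lower) * rng.random(z_vals.shape)
    return z_vals


near = np.full((2, 1), 2.0)
far = np.full((2, 1), 6.0)
print(stratified_depths(near, far, 5, perturb=False))  # both rows: 2, 3, 4, 5, 6
```

The bins do not overlap, so the jittered samples stay sorted along each ray. The renderer relies on this ordering when it computes the distances between neighbouring samples.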


def sample_pdf(bins, weights, cap_cap_n_samples, det=False):
    """
    Sample pdf function.

    Args:
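        bins (Tensor): Bin edges along each ray; the last dimension must be one larger than that of weights.
        weights (Tensor): The estimated non-negative weight of each bin.
        cap_cap_n_samples (int): The number of points to be sampled.
        det (bool, optional): If True, use evenly spaced values instead of uniform random values when inverting the CDF. Default: False.

    Returns:
        Tensor, new sample positions drawn by inverse-transform sampling from the piecewise-constant PDF that weights defines over bins.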
    """
    weights = weights + 1e-5
    pdf = weights / md.numpy.sum(weights, -1, keepdims=True)
    cdf = md.numpy.cumsum(pdf, -1)
    cdf = md.ops.Concat(-1)([md.numpy.zeros_like(cdf[..., :1]), cdf])

    # Take uniform samples
    temp_shape = cdf.shape[:-1]
    cap_cap_n_samples_new = cap_cap_n_samples
    temp_shape_new = list(temp_shape) + [cap_cap_n_samples_new]
    if det:
        u = md.numpy.linspace(0.0, 1.0, num=cap_cap_n_samples)
        expand_op = md.ops.BroadcastTo(temp_shape_new)
        u = expand_op(u)
    else:
        u = md.numpy.rand(temp_shape_new)

    # Invert CDF
    indexes = nd_searchsorted(cdf, u)

    below = md.numpy.maximum(md.numpy.zeros_like(indexes - 1), indexes - 1)
    above = md.numpy.minimum((cdf.shape[-1] - 1) * md.numpy.ones_like(indexes),
                             indexes)
    indexes_g = md.ops.Stack(axis=-1)([below, above])

    matched_shape = (indexes_g.shape[0], indexes_g.shape[1], cdf.shape[-1])
    gather_op = md.ops.GatherD()
    unsqueeze_op = md.ops.ExpandDims()
    expand_op = md.ops.BroadcastTo(matched_shape)
    cdf_g = gather_op(expand_op(unsqueeze_op(cdf, 1)), 2, indexes_g)
    bins_g = gather_op(expand_op(unsqueeze_op(bins, 1)), 2, indexes_g)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = md.numpy.where(denom < 1e-5, md.numpy.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
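
`sample_pdf` performs inverse-transform sampling in four steps:

1. Normalize the weights into a PDF.
2. Accumulate the PDF into a CDF.
3. Find the CDF interval that each value u falls into.
4. Interpolate linearly inside the matching bin.

The NumPy sketch below follows the same steps, using `np.searchsorted` and `np.take_along_axis` in place of `nd_searchsorted` and `GatherD`. `sample_pdf_np` is a hypothetical name, and the inputs are assumed to be 2-D (rays x bins).

```python
import numpy as np


def sample_pdf_np(bins, weights, n_samples, det=False, rng=None):
    """Inverse-transform sampling from a piecewise-constant PDF (NumPy sketch).

    bins:    (n_rays, n_bins + 1) bin edges along each ray
    weights: (n_rays, n_bins) non-negative weight of each bin
    Returns: (n_rays, n_samples) new positions, concentrated where weights are large.
    """
    rng = np.random.default_rng() if rng is None else rng
    weights = weights + 1e-5  # avoid an all-zero PDF
    pdf = weights / weights.sum(-1, keepdims=True)
    cdf = np.concatenate([np.zeros_like(pdf[..., :1]), np.cumsum(pdf, -1)], -1)

    shape = cdf.shape[:-1] + (n_samples,)
    if det:
        u = np.broadcast_to(np.linspace(0.0, 1.0, n_samples), shape)
    else:
        u = rng.random(shape)

    # Index of the first CDF entry strictly greater than u, computed row by row
    idx = np.stack([np.searchsorted(c, q, side="right") for c, q in zip(cdf, u)])
    below = np.maximum(idx - 1, 0)
    above = np.minimum(idx, cdf.shape[-1] - 1)

    cdf_lo = np.take_along_axis(cdf, below, -1)
    cdf_hi = np.take_along_axis(cdf, above, -1)
    bin_lo = np.take_along_axis(bins, below, -1)
    bin_hi = np.take_along_axis(bins, above, -1)

    denom = cdf_hi - cdf_lo
    denom = np.where(denom < 1e-5, 1.0, denom)  # guard against near-empty intervals
    t = (u - cdf_lo) / denom  # relative position inside the chosen bin
    return bin_lo + t * (bin_hi - bin_lo)


bins = np.array([[0.0, 1.0, 2.0, 3.0, 4.0]])
weights = np.array([[0.0, 0.0, 1.0, 0.0]])  # nearly all mass in the bin [2, 3]
print(sample_pdf_np(bins, weights, 8, det=True).round(3))
```

The original NeRF implementation uses this routine for hierarchical sampling. There, the bins are midpoints of the coarse samples and the weights come from the coarse network, so the fine samples cluster near visible surfaces. In the example above, every sample except the two endpoints lands in [2, 3], the bin that holds almost all of the weight.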


def nd_searchsorted(cdf, u):
    """N-dim searchsorted.

    Args:
        cdf (Tensor): The cdf sampling weights.
        u (Tensor): The interval tensors.

    Returns:
        Tensor, index after searchsorted ops.
    """
    spatial_shape = cdf.shape[:-1]
    last_dim_cdf, last_dim_u = cdf.shape[-1], u.shape[-1]
    cdf_, u_ = cdf.view(-1, last_dim_cdf), u.view(-1, last_dim_u)
    indexes_ls = []

    for i in range(cdf_.shape[0]):
        indexes_ls.append(cdf_[i].searchsorted(u_[i], side="right"))
    indexes = md.ops.Stack(axis=0)(indexes_ls)
    indexes = indexes.view(*spatial_shape, last_dim_u)
    return indexes
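
`nd_searchsorted` loops over the rows in Python and issues one `searchsorted` call per ray. A loop-free alternative counts, for every query, how many CDF entries are less than or equal to it; that count is exactly the `side="right"` index. The NumPy sketch below shows this broadcast comparison and checks it against the loop. `searchsorted_right_batched` is a hypothetical name. The approach trades the Python loop for an intermediate array of shape (..., M, K). Whether an equivalent MindSpore expression is faster than the original loop is untested.

```python
import numpy as np


def searchsorted_right_batched(cdf, u):
    """Row-wise searchsorted(side="right") without a Python loop (NumPy sketch).

    cdf: (..., K) sorted along the last axis; u: (..., M) queries.
    Returns (..., M) indices i with cdf[..., i-1] <= u < cdf[..., i].
    """
    # Compare every query against every CDF entry of its row, then count the hits
    return (cdf[..., None, :] <= u[..., :, None]).sum(-1)


cdf = np.array([[0.0, 0.2, 0.7, 1.0],
                [0.0, 0.5, 0.6, 1.0]])
u = np.array([[0.1, 0.7, 1.0],
              [0.55, 0.0, 0.9]])
loop = np.stack([np.searchsorted(c, q, side="right") for c, q in zip(cdf, u)])
print(searchsorted_right_batched(cdf, u))  # rows: [1 3 4] and [2 1 3]
```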

train_image_list = []
with tqdm(range(1, cap_n_iters + 1)) as pbar:
    pbar.n = start_iter

    for i in pbar:
        # Show progress
        pbar.set_description(f'Iter {global_steps + 1:d}')
        pbar.update()

        # Start time of the current iteration
        time0 = time.time()

        img_i = int(np.random.choice(i_train))

        target = images[img_i]
        pose = poses[img_i, :3, :4]

        if cap_n_rand is not None:
            rays_o, rays_d = generate_rays(
                cap_h, cap_w, focal,
                pose)  # (cap_h, cap_w, 3), (cap_h, cap_w, 3)
            sampled_rows, sampled_cols = sample_grid_2d(
                cap_h, cap_w, cap_n_rand)
            rays_o = rays_o[sampled_rows, sampled_cols]  # (cap_n_rand, 3)
            rays_d = rays_d[sampled_rows, sampled_cols]  # (cap_n_rand, 3)

            batch_rays = md.ops.Stack(axis=0)([rays_o, rays_d])
            target_s = target[sampled_rows, sampled_cols]  # (cap_n_rand, 3)

        loss, psnr = train_net(global_steps, train_renderer, optimizer,
                               batch_rays, target_s)

        pbar.set_postfix(time=time.time() - time0, loss=loss, psnr=psnr)

        # Save testing results
        if (global_steps + 1) % i_testset == 0:
            test_idx = np.random.randint(low=0, high=len(i_test))
            test_time, test_loss, test_psnr, sub_train_image_list = test_net(
                cap_h, cap_w, focal, renderer,
                poses[i_test[test_idx:test_idx + 1].tolist()],
                images[i_test[test_idx:test_idx + 1].tolist()])
            train_image_list.extend(sub_train_image_list)
        global_steps += 1
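
import os
import matplotlib.pyplot as plt
import matplotlib.animation as animation
%matplotlib inline

def showGif(image_list, name):
    """Save a list of RGB frames as an animated GIF at images/{name}.gif."""
    os.makedirs("images", exist_ok=True)  # ani.save() fails if the output folder does not exist
    show_list = []
    fig = plt.figure(figsize=(4, 4), dpi=120)
    for frame in image_list:
        plt.axis("off")
        show_list.append([plt.imshow(frame)])  # one artist per animation frame

    # Show each frame for 1 s and write the animation with the Pillow writer
    ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)
    ani.save(f'images/{name}.gif', writer='pillow', fps=1)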

showGif(train_image_list, "nerf.train")

Rendering

test_time, test_loss, test_psnr, test_image_list = test_net(
    cap_h,
    cap_w,
    focal,
    renderer,
    md.Tensor(poses[i_test.tolist()]),
    images[i_test.tolist()])
print(
    f"Testing results: [ Mean Time: {test_time:.4f}s, Loss: {test_loss:.4f}, PSNR: {test_psnr:.4f} ]"
)
showGif(test_image_list, "nerf.test")

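Inference

Use the pretrained weights to run inference and reproduce the best results. Restarting the Jupyter kernel before running the following code blocks is recommended to avoid running out of memory.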

device = "GPU"
device_id = 0
mode = "GRAPH_MODE"

# Download the weights
!if [[ ! -f ckpts/200000.blender_lego_coarse_0.ckpt ]]; then mkdir -p ckpts/ && wget -P ckpts/ https://download.mindspore.cn/vision/nerf/lego/200000.blender_lego_coarse_0.ckpt; fi
# Run the inference script
!cd src && GLOG_v=3 python eval.py --name lego_coarse --data_dir ../datasets/nerf_synthetics/lego --dataset_type blender --half_res --cap_n_rand 4096 --cap_n_samples 192 --cap
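
import os
import glob
import imageio
import matplotlib.pyplot as plt
import matplotlib.animation as animation
%matplotlib inline

# Redefined here because the kernel was restarted before inference
def showGif(image_list, name):
    """Save a list of RGB frames as an animated GIF at images/{name}.gif."""
    os.makedirs("images", exist_ok=True)  # ani.save() fails if the output folder does not exist
    show_list = []
    fig = plt.figure(figsize=(4, 4), dpi=120)
    for frame in image_list:
        plt.axis("off")
        show_list.append([plt.imshow(frame)])  # one artist per animation frame

    # Show each frame for 1 s and write the animation with the Pillow writer
    ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)
    ani.save(f'images/{name}.gif', writer='pillow', fps=1)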

inference_image_list = [imageio.imread(image_path) for image_path in sorted(glob.glob("src/results/blender_lego_coarse_0/renderonly_200000/*.png"))]
showGif(inference_image_list, "nerf.inference")

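Summary

This case study explains the model proposed in the NeRF paper in detail and walks the reader through the complete pipeline of the algorithm.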

References

[1] Mildenhall, Ben et al. “NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis.” ECCV (2020).

Further Reading

NeRF repository: the official NeRF repository.

Link: https://github.com/bmild/nerf

NeRF video: a video explaining NeRF.

Link: https://www.youtube.com/watch?v=dPWLybp4LL0

For more MindSpore application cases, see the developer cases on the official website: https://www.mindspore.cn/resources/cases

MindSpore Official Resources

Official QQ group: 871543426

Official website: https://www.mindspore.cn/

Gitee: https://gitee.com/mindspore/mindspore

GitHub: https://github.com/mindspore-ai/mindspore

Forum: https://www.hiascend.com/forum/forum-0106101385921175002-1.html

OpenI community: https://openi.org.cn
