Nerf(Neural Radiance Fields)是一种用于三维重建和图像合成的机器学习技术。它基于深度学习,使用神经网络来预测场景中每个点的颜色和密度,从而生成高质量的三维重建结果。
Nerf 通过训练神经网络从不同角度的图像中学习场景的表面和光照特征,然后使用学习到的信息来生成新的视角的图像。与传统的三维重建方法不同,Nerf 不需要对场景进行显式的几何建模,也不需要使用多张图像进行立体匹配。相反,它使用单张图像和相机参数来训练神经网络,从而生成高质量的三维重建结果。
Nerf 技术已经被广泛应用于虚拟现实、增强现实和电影等领域,可以生成逼真的三维场景和高质量的图像。同时,它也是当前计算机视觉和深度学习领域的研究热点之一,引起了广泛的关注和研究。
train-nerf.py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import cv2
import os
import json
import argparse
import imageio
# 主要知识点
# 1. 位置编码,Positional Encoding
# - 对于输入的x、y、z坐标,因为是连续的无法进行区分,因此采用ff特征,即傅立叶特征进行编码
# - 编码为cos、sin不同频率的叠加,使得连续值可以具有足够的区分性
# 2. 视图独立性,View Dependent
# - 输入不仅仅是光线采样点的x、y、z坐标,加上了视图依赖,即x、y、z、theta、pi,5d输入,此时多了射线所在视图
# 3. 分层采样,Hierarchical sampling
# - 将渲染分为两级,由于第一级别的模型是均匀采样,而实际会有很多无效的采样(即对颜色没有贡献的区域会占比太多),在模型
# 中看来,就是某些点的梯度为0,对模型训练没有贡献
# - 因此采用两级模型,model、fine,model模型使用均匀采样,推断后得到weights的分布,通过对weights分布进行重采样,使得采样点
# 更加集中在更重要的区域,今儿使得参与训练的点大都是有效的点。所以model作为一级推理,fine则推理重采样后的点
#
# x. 拓展,对于射线的方向和原点的理解,需要具有基本的3d变换知识,建议看GAMES101的前5章补充知识
# PSNR是峰值信噪比,表示重建的逼真程度
# 这三个环节有了,效果就会非常逼真,但是某些细节上还是存在不足。另外训练时间非常关键
class BlenderProvider:
def __init__(self, root, transforms_file, half_resolution=True):
self.meta = json.load(open(os.path.join(root, transforms_file), "r"))
self.root = root
self.frames = self.meta["frames"]
self.images = []
self.poses = []
self.camera_angle_x = self.meta["camera_angle_x"]
for frame in self.frames:
image_file = os.path.join(self.root, frame["file_path"] + ".png")
image = imageio.imread(image_file)
if half_resolution:
image = cv2.resize(image, dsize=None, fx=0.5, fy=0.5, interpolation=cv2.INTER_AREA)
self.images.append(image)
self.poses.append(frame["transform_matrix"])
self.poses = np.stack(self.poses)
self.images = (np.stack(self.images) / 255.0).astype(np.float32)
self.width = self.images.shape[2]
self.height = self.images.shape[1]
self.focal = 0.5 * self.width / np.tan(0.5 * self.camera_angle_x)
alpha = self.images[..., [3]]
rgb = self.images[..., :3]
self.images = rgb * alpha + (1 - alpha)
class NeRFDataset:
def __init__(self, provider, batch_size=1024, device="cpu"):
self.images = provider.images
self.poses = provider.poses
self.focal = provider.focal
self.width = provider.width
self.height = provider.height
self.batch_size = batch_size
self.num_image = len(self.images)
self.precrop_iters = 500
self.precrop_frac = 0.5
self.niter = 0
self.device = device
self.initialize()
def initialize(self):
warange = torch.arange(self.width, dtype=torch.float32, device=self.device)
harange = torch.arange(self.height, dtype=torch.float32, device=self.device)
y, x = torch.meshgrid(harange, warange)
self.transformed_x = (x - self.width * 0.5) / self.focal
self.transformed_y = (y - self.height * 0.5) / self.focal
# pre center crop
self.precrop_index = torch.arange(self.width * self.height).view(self.height, self.width)
dH = int(self.height // 2 * self.precrop_frac)
dW = int(self.width // 2 * self.precrop_frac)
self.precrop_index = self.precrop_index[
self.height // 2 - dH:self.height // 2 + dH,
self.width // 2 - dW:self.width // 2 + dW
].reshape(-1)
poses = torch.FloatTensor(self.poses, device=self.device)
all_ray_dirs, all_ray_origins = [], []
for i in range(len(self.images)):
ray_dirs, ray_origins = self.make_rays(self.transformed_x, self.transformed_y, poses[i])
all_ray_dirs.append(ray_dirs)
all_ray_origins.append(ray_origins)
self.all_ray_dirs = torch.stack(all_ray_dirs, dim=0)
self.all_ray_origins = torch.stack(all_ray_origins, dim=0)
self.images = torch.FloatTensor(self.images, device=self.device).view(self.num_image, -1, 3)
def __getitem__(self, index):
self.niter += 1
ray_dirs = self.all_ray_dirs[index]
ray_oris = self.all_ray_origins[index]
img_pixels = self.images[index]
if self.niter < self.precrop_iters:
ray_dirs = ray_dirs[self.precrop_index]
ray_oris = ray_oris[self.precrop_index]
img_pixels = img_pixels[self.precrop_index]
nrays = self.batch_size
select_inds = np.random.choice(ray_dirs.shape[0], size=[nrays], replace=False)
ray_dirs = ray_dirs[select_inds]
ray_oris = ray_oris[select_inds]
img_pixels = img_pixels[select_inds]
# dirs是指:direction
# ori是指: origin