Repository: https://github.com/sxyu/pixel-nerf
There is a lot here, so these notes do not cover everything.
pixelNeRF: ReadME
Alex Yu, Vickie Ye, Matthew Tancik, Angjoo Kanazawa
UC Berkeley
arXiv: http://arxiv.org/abs/2012.02190
This is the official repository for our paper pixelNeRF, pending final release. The two-object experiments are still missing. Several features may also be added.
Environment setup
To begin, we prefer creating the environment using conda:
conda env create -f environment.yml
conda activate pixelnerf
environment.yml
# run: conda env create -f environment.yml
name: pixelnerf
channels:
  - pytorch
  - defaults
dependencies:
  - python>=3.8
  - pip
  - pip:
      - pyhocon
      - opencv-python
      - dotmap
      - tensorboard
      - imageio
      - imageio-ffmpeg
      - ipdb
      - pretrainedmodels
      - lpips
      - scipy
      - numpy
      - matplotlib
  - pytorch==1.6.0
  - torchvision==0.7.0
  - scikit-image==0.17.2
  - tqdm
Please make sure your NVIDIA driver supports at least CUDA 10.2.
Alternatively, you can use pip install -r requirements.txt:
requirements.txt
torch
torchvision
pretrainedmodels
pyhocon
imageio
opencv-python
imageio-ffmpeg
tensorboard
dotmap
numpy
scipy
scikit-image
ipdb
matplotlib
tqdm
lpips
Getting the data
- For the main ShapeNet experiments, we use the ShapeNet 64x64 dataset from NMR:
https://s3.eu-central-1.amazonaws.com/avg-projects/differentiable_volumetric_rendering/data/NMR_Dataset.zip
(hosted by the DVR authors)
- For the novel-category generalization experiment, a custom split is needed. Download the following script:
https://drive.google.com/file/d/1Uxf0GguAUTSFIDD_7zuPbxk1C9WgXjce/view?usp=sharing
Place the file under NMR_Dataset and run python genlist.py
in that directory. This generates train/val/test lists for the
experiment. For evaluation-performance reasons, the test set is only 1/4
of the unseen categories.
- The remaining datasets may be found in
https://drive.google.com/drive/folders/1PsT3uKwqHHD2bEEHkIXB99AlIjtmrEiR?usp=sharing
- Custom two-chair data:
multi_chair_{train/val/test}.zip
. Download the splits into a parent directory and pass the parent directory path to the training command.
- To render your own dataset, feel free to use our script
scripts/render_shapenet.py.
See scripts/README.md
for installation instructions.
- DTU (4x downsampled, rescaled) in DVR’s DTU format
dtu_dataset.zip
- SRN chair/car (128x128)
srn_*.zip
, needed for the single-category experiments.
Note that the car scenes are a re-rendered version provided by Vincent Sitzmann.
While we could have used a unified data format, we chose to keep the DTU and ShapeNet (NMR) datasets in DVR's format and the SRN data in the original SRN format. Our own two-object data uses the NeRF format. Data adapters are built into the code.
Running the model (video generation)
The main implementation is in the src/
directory, while evaluation scripts are in eval/.
First, download all pretrained weight files from https://drive.google.com/file/d/1UO_rL201guN6euoWkCOn-XpqR2e8o6ju/view?usp=sharing. Extract to <project dir>/checkpoints/
, so that <project dir>/checkpoints/dtu/pixel_nerf_latest
exists.
ShapeNet Multiple Categories (NMR)
1. Download the NMR ShapeNet renderings (see the Datasets section, first link)
2. Run using
python eval/gen_video.py -n sn64 --gpu_id <GPU(s)> --split test -P '2' -D <data_root>/NMR_Dataset -S 0
For unseen-category generalization:
python eval/gen_video.py -n sn64_unseen --gpu_id=<GPU(s)> --split test -P '2' -D <data_root>/NMR_Dataset -S 0
Replace <GPU(s)>
with the desired GPU id(s); separate multiple ids with spaces.
Replace -S 0
with -S <object_id>
to run on a different ShapeNet object id.
Replace -P '2'
with -P '<number>'
to use a different input view. Replace --split test
with --split train | val
to use a different data split.
Append -R=20000
if running out of memory.
Result will be at visuals/sn64/videot<object_id>.mp4
or visuals/sn64_unseen/videot<object_id>.mp4.
The script will also print the path.
Pre-generated results for all ShapeNet objects, with comparisons, are available at https://www.ocf.berkeley.edu/~sxyu/ZG9yaWF0aA/pixelnerf/cross_v2/
ShapeNet Single-Category (SRN)
1. Download the SRN car (or chair) dataset from the Google Drive folder in the Datasets section. Extract to <srn data dir>/cars_<train | test | val>
2. Run using
python eval/gen_video.py -n srn_car --gpu_id=<GPU (s)> --split test -P '64 104' -D <srn data dir>/cars -S 1
Use -P 64
for the 1-view case (the view numbers come from SRN).
The chair set works analogously (substitute chairs for cars). During training, our model was trained with a random 1 or 2 views per batch. This seems to hurt performance, especially in the single-view case; using a fixed number of views would be better.
DTU
Make sure you have downloaded the pretrained weights above.
1. Download the DTU dataset from the Google Drive folder in the Datasets section. Extract to some directory, to get: <data_root>/rs_dtu_4
2. Run using
python eval/gen_video.py -n dtu --gpu_id=<GPU(s)> --split val -P '22 25 28' -D <data_root>/rs_dtu_4 -S 3 --scale 0.25
Replace <GPU(s)>
with the desired GPU id(s).
Replace -S 3 with -S <scene_id> to run on a different scene.
This is not the DTU scene number but an index 0-14 into the val set.
Remove --scale 0.25
to render at full resolution (quite slow).
Result will be at visuals/dtu/videov<scene_id>.mp4.
The script will also print the path.
Note that for DTU I use only the train/val sets, with val serving as the test set. This is because the dataset is very small and the model clearly overfits the training set during training.
Real Car Images
Note: requires PointRend from detectron2. Install detectron2 by following https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md.
Make sure you have downloaded the pretrained weights above.
1. Download any car image and place it in <project dir>/input
. The repo ships with some example images. The car should be fully visible.
2. Run the preprocessing script:
python scripts/preproc.py
This saves input/*_normalize.png.
If the result looks unreasonable, PointRend did not work; try another image.
3. Run
python eval/eval_real.py
Outputs will be in <project dir>/output
The Stanford Cars dataset contains many example car images: https://ai.stanford.edu/~jkrause/cars/car_dataset.html.
Note that the normalization heuristic is slightly modified compared to the paper, so there may be small differences. You can pass -e -20
to eval_real.py
to set the elevation higher in the generated video.
Overview of flags
In general, all scripts in the project take a common set of flags (see the README; not reproduced in these notes).
Refer to the README's table for the list of provided experiments with their associated config and data files.
Quantitative evaluation instructions
(omitted)
...
train.py
Insert the src directory into the system path (train.py imports sys and os at the top of the file for this):
import sys
import os

sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))
)
so that from render import NeRFRenderer
can be written instead of from src.render import NeRFRenderer
import
import warnings
import trainlib
from model import make_model, loss
from render import NeRFRenderer
from data import get_split_dataset
import util
import numpy as np
import torch.nn.functional as F
import torch
from dotmap import DotMap
config
args, conf = util.args.parse_args(extra_args, training=True, default_ray_batch_size=128)
device = util.get_cuda(args.gpu_id[0])
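util.get_cuda resolves a torch device for the first listed GPU id. A minimal sketch of what this helper is assumed to do (the real one lives in src/util; the fallback behavior here is an assumption):

import torch

def get_cuda(gpu_id):
    # Return the CUDA device with the given id, falling back to CPU
    # when CUDA is unavailable.
    return (
        torch.device("cuda:%d" % gpu_id)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )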
def extra_args(parser):
parser.add_argument(
"--batch_size", "-B", type=int, default=4, help="Object batch size ('SB')"
)
parser.add_argument(
"--nviews",
"-V",
type=str,
default="1",
help="Number of source views (multiview); put multiple (space delim) to pick randomly per batch ('NV')",
) # number of source views (multiview); pass several, space-delimited, to pick randomly per batch ('NV')
parser.add_argument(
"--freeze_enc",
action="store_true",
default=None,
help="Freeze encoder weights and only train MLP",
)
parser.add_argument(
"--no_bbox_step",
type=int,
default=100000,
help="Step to stop using bbox sampling",
)
    parser.add_argument(
        "--fixed_test",
        action="store_true",
        default=None,
        # NOTE: this help string is copy-pasted from --freeze_enc in the
        # original source; the flag appears to control fixed test-view selection.
        help="Freeze encoder weights and only train MLP",
    )
return parser
# parse_args is defined in src/util/args.py, which imports argparse, os, and
# pyhocon's ConfigFactory at the top of that module.
def parse_args(
callback=None,
training=False,
default_conf="conf/default_mv.conf",
default_expname="example",
default_data_format="dvr",
default_num_epochs=10000000,
default_lr=1e-4,
default_gamma=1.00,
default_datadir="data",
default_ray_batch_size=50000,
):
parser = argparse.ArgumentParser()
parser.add_argument("--conf", "-c", type=str, default=None)
parser.add_argument("--resume", "-r", action="store_true", help="continue training")
parser.add_argument(
"--gpu_id", type=str, default="0", help="GPU(s) to use, space delimited"
)
parser.add_argument(
"--name", "-n", type=str, default=default_expname, help="experiment name"
)
parser.add_argument(
"--dataset_format",
"-F",
type=str,
default=None,
help="Dataset format, multi_obj | dvr | dvr_gen | dvr_dtu | srn",
)
parser.add_argument(
"--exp_group_name",
"-G",
type=str,
default=None,
help="if we want to group some experiments together",
)
parser.add_argument(
"--logs_path", type=str, default="logs", help="logs output directory",
)
parser.add_argument(
"--checkpoints_path",
type=str,
default="checkpoints",
help="checkpoints output directory",
)
parser.add_argument(
"--visual_path",
type=str,
default="visuals",
help="visualization output directory",
)
parser.add_argument(
"--epochs",
type=int,
default=default_num_epochs,
help="number of epochs to train for",
)
parser.add_argument("--lr", type=float, default=default_lr, help="learning rate")
parser.add_argument(
"--gamma", type=float, default=default_gamma, help="learning rate decay factor"
)
parser.add_argument(
"--datadir", "-D", type=str, default=None, help="Dataset directory"
)
parser.add_argument(
"--ray_batch_size", "-R", type=int, default=default_ray_batch_size, help="Ray batch size"
)
if callback is not None:
parser = callback(parser)
args = parser.parse_args()
if args.exp_group_name is not None:
args.logs_path = os.path.join(args.logs_path, args.exp_group_name)
args.checkpoints_path = os.path.join(args.checkpoints_path, args.exp_group_name)
args.visual_path = os.path.join(args.visual_path, args.exp_group_name)
os.makedirs(os.path.join(args.checkpoints_path, args.name), exist_ok=True)
os.makedirs(os.path.join(args.visual_path, args.name), exist_ok=True)
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
EXPCONF_PATH = os.path.join(PROJECT_ROOT, "expconf.conf")
expconf = ConfigFactory.parse_file(EXPCONF_PATH)
if args.conf is None:
args.conf = expconf.get_string("config." + args.name, default_conf)
if args.datadir is None:
args.datadir = expconf.get_string("datadir." + args.name, default_datadir)
conf = ConfigFactory.parse_file(args.conf)
if args.dataset_format is None:
args.dataset_format = conf.get_string("data.format", default_data_format)
args.gpu_id = list(map(int, args.gpu_id.split()))
print("EXPERIMENT NAME:", args.name)
if training:
print("CONTINUE?", "yes" if args.resume else "no")
print("* Config file:", args.conf)
print("* Dataset format:", args.dataset_format)
print("* Dataset location:", args.datadir)
return args, conf
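As a usage example (hedged: the experiment and config names below follow the repository's README and conf/exp directory), a typical training invocation looks like:
python train/train.py -n srn_car -c conf/exp/srn.conf -D <srn data dir>/cars --gpu_id=0 --resume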
default.conf
# Single-view only base model
# (Not used in the experiments; inherited by resnet_fine_mv.conf)
model {
# Condition on local encoder
use_encoder = True
# Condition also on a global encoder?
use_global_encoder = False
# Use xyz input instead of just z
# (didn't ablate)
use_xyz = True
# Canonical space xyz (default: view space)
canon_xyz = False
# Positional encoding (a standalone sketch follows this config listing)
use_code = True
code {
num_freqs = 6
freq_factor = 1.5
include_input = True
}
# View directions
use_viewdirs = True
# Apply pos. enc. to viewdirs?
use_code_viewdirs = False
# MLP architecture
mlp_coarse {
type = resnet # Can change to mlp
n_blocks = 3
d_hidden = 512
}
mlp_fine {
type = resnet
n_blocks = 3
d_hidden = 512
}
# Encoder architecture
encoder {
backbone = resnet34
pretrained = True
num_layers = 4
}
}
renderer {
n_coarse = 64
n_fine = 32
# Try using expected depth sample
n_fine_depth = 16
# Noise to add to depth sample
depth_std = 0.01
# Decay schedule, not used
sched = []
# White background color (false : black)
white_bkgd = True
}
loss {
# RGB losses coarse/fine
rgb {
use_l1 = False
}
rgb_fine {
use_l1 = False
}
# Alpha regularization (disabled in the final version)
alpha {
# lambda_alpha = 0.0001
lambda_alpha = 0.0
clamp_alpha = 100
init_epoch = 5
}
# Coarse/fine weighting (nerf = equal)
lambda_coarse = 1.0 # loss = lambda_coarse * loss_coarse + loss_fine
lambda_fine = 1.0 # loss = lambda_coarse * loss_coarse + loss_fine
}
train {
# Training
print_interval = 2
save_interval = 50
vis_interval = 100
eval_interval = 50
# Accumulate gradients. Not really recommended.
# 1 = disable
accu_grad = 1
# Number of times to repeat dataset per 'epoch'
# Useful if the dataset is very small, e.g. DTU
num_epoch_repeats = 1
}
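The code { ... } block in the model section configures NeRF-style positional encoding. A minimal sketch of what such an encoding computes, assuming the standard NeRF formulation with frequencies freq_factor * 2^i (the real implementation is in src/model/code.py and may differ in details):

import torch

def positional_encoding(x, num_freqs=6, freq_factor=1.5, include_input=True):
    # x: (B, d_in) coordinates -> (B, d_out) features, where
    # d_out = d_in * 2 * num_freqs (+ d_in if include_input).
    freqs = freq_factor * 2.0 ** torch.arange(num_freqs, dtype=x.dtype)  # (F,)
    xb = x.unsqueeze(-1) * freqs                                         # (B, d_in, F)
    enc = torch.cat((torch.sin(xb), torch.cos(xb)), dim=-1).flatten(1)   # (B, d_in*2F)
    return torch.cat((x, enc), dim=-1) if include_input else enc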
dataset
dset, val_dset, _ = get_split_dataset(args.dataset_format, args.datadir)
print(
"dset z_near {}, z_far {}, lindisp {}".format(dset.z_near, dset.z_far, dset.lindisp)
)
# From src/data/__init__.py; SRNDataset, DVRDataset, MultiObjectDataset, and
# ColorJitterDataset are imported at the top of that module.
def get_split_dataset(dataset_type, datadir, want_split="all", training=True, **kwargs):
    """
    Retrieve the desired dataset class
    :param dataset_type dataset type name (srn|dvr|dvr_gen, etc.)
    :param datadir root directory of the dataset. For SRN/multi_obj data:
    if data is in dir/cars_train, dir/cars_test, ... then put dir/cars
    :param want_split which split(s) to return: train | val | test | all
    :param training set to False in eval scripts
    """
dset_class, train_aug = None, None
flags, train_aug_flags = {}, {}
if dataset_type == "srn":
# For ShapeNet single-category (from SRN)
dset_class = SRNDataset
elif dataset_type == "multi_obj":
# For multiple-object
dset_class = MultiObjectDataset
elif dataset_type.startswith("dvr"):
# For ShapeNet 64x64
dset_class = DVRDataset
if dataset_type == "dvr_gen":
# For generalization training (train some categories, eval on others)
flags["list_prefix"] = "gen_"
elif dataset_type == "dvr_dtu":
# DTU dataset
flags["list_prefix"] = "new_"
if training:
flags["max_imgs"] = 49
flags["sub_format"] = "dtu"
flags["scale_focal"] = False
flags["z_near"] = 0.1
flags["z_far"] = 5.0
# Apply color jitter during train
train_aug = ColorJitterDataset
train_aug_flags = {"extra_inherit_attrs": ["sub_format"]}
else:
raise NotImplementedError("Unsupported dataset type", dataset_type)
want_train = want_split != "val" and want_split != "test"
want_val = want_split != "train" and want_split != "test"
want_test = want_split != "train" and want_split != "val"
if want_train:
train_set = dset_class(datadir, stage="train", **flags, **kwargs)
if train_aug is not None:
train_set = train_aug(train_set, **train_aug_flags)
if want_val:
val_set = dset_class(datadir, stage="val", **flags, **kwargs)
if want_test:
test_set = dset_class(datadir, stage="test", **flags, **kwargs)
if want_split == "train":
return train_set
elif want_split == "val":
return val_set
elif want_split == "test":
return test_set
return train_set, val_set, test_set
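For example (a sketch; the data path is a placeholder), train.py effectively calls:

# want_split defaults to "all", so all three splits come back
dset, val_dset, test_dset = get_split_dataset("srn", "<srn data dir>/cars")
print(dset.z_near, dset.z_far, dset.lindisp)  # 0.8 1.8 False for SRN cars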
# From src/data/SRNDataset.py; that module also imports os, glob, imageio,
# numpy as np, torch.nn.functional as F, and the image/mask helpers from util.
class SRNDataset(torch.utils.data.Dataset):
"""
Dataset from SRN (V. Sitzmann et al. 2020)
"""
def __init__(
self, path, stage="train", image_size=(128, 128), world_scale=1.0,
):
"""
:param stage train | val | test
:param image_size result image size (resizes if different)
:param world_scale amount to scale the entire world by
"""
super().__init__()
self.base_path = path + "_" + stage
self.dataset_name = os.path.basename(path)
print("Loading SRN dataset", self.base_path, "name:", self.dataset_name)
self.stage = stage
assert os.path.exists(self.base_path)
is_chair = "chair" in self.dataset_name
if is_chair and stage == "train":
# Ugly thing from SRN's public dataset
tmp = os.path.join(self.base_path, "chairs_2.0_train")
if os.path.exists(tmp):
self.base_path = tmp
self.intrins = sorted(
glob.glob(os.path.join(self.base_path, "*", "intrinsics.txt"))
)
self.image_to_tensor = get_image_to_tensor_balanced()
self.mask_to_tensor = get_mask_to_tensor()
self.image_size = image_size
self.world_scale = world_scale
self._coord_trans = torch.diag(
torch.tensor([1, -1, -1, 1], dtype=torch.float32)
)
if is_chair:
self.z_near = 1.25
self.z_far = 2.75
else:
self.z_near = 0.8
self.z_far = 1.8
self.lindisp = False
def __len__(self):
return len(self.intrins)
def __getitem__(self, index):
intrin_path = self.intrins[index]
dir_path = os.path.dirname(intrin_path)
rgb_paths = sorted(glob.glob(os.path.join(dir_path, "rgb", "*")))
pose_paths = sorted(glob.glob(os.path.join(dir_path, "pose", "*")))
assert len(rgb_paths) == len(pose_paths)
with open(intrin_path, "r") as intrinfile:
lines = intrinfile.readlines()
focal, cx, cy, _ = map(float, lines[0].split())
height, width = map(int, lines[-1].split())
all_imgs = []
all_poses = []
all_masks = []
all_bboxes = []
for rgb_path, pose_path in zip(rgb_paths, pose_paths):
img = imageio.imread(rgb_path)[..., :3]
img_tensor = self.image_to_tensor(img)
mask = (img != 255).all(axis=-1)[..., None].astype(np.uint8) * 255
mask_tensor = self.mask_to_tensor(mask)
pose = torch.from_numpy(
np.loadtxt(pose_path, dtype=np.float32).reshape(4, 4)
)
pose = pose @ self._coord_trans
rows = np.any(mask, axis=1)
cols = np.any(mask, axis=0)
rnz = np.where(rows)[0]
cnz = np.where(cols)[0]
if len(rnz) == 0:
raise RuntimeError(
"ERROR: Bad image at", rgb_path, "please investigate!"
)
rmin, rmax = rnz[[0, -1]]
cmin, cmax = cnz[[0, -1]]
bbox = torch.tensor([cmin, rmin, cmax, rmax], dtype=torch.float32)
all_imgs.append(img_tensor)
all_masks.append(mask_tensor)
all_poses.append(pose)
all_bboxes.append(bbox)
all_imgs = torch.stack(all_imgs)
all_poses = torch.stack(all_poses)
all_masks = torch.stack(all_masks)
all_bboxes = torch.stack(all_bboxes)
if all_imgs.shape[-2:] != self.image_size:
scale = self.image_size[0] / all_imgs.shape[-2]
focal *= scale
cx *= scale
cy *= scale
all_bboxes *= scale
all_imgs = F.interpolate(all_imgs, size=self.image_size, mode="area")
all_masks = F.interpolate(all_masks, size=self.image_size, mode="area")
if self.world_scale != 1.0:
focal *= self.world_scale
all_poses[:, :3, 3] *= self.world_scale
focal = torch.tensor(focal, dtype=torch.float32)
result = {
"path": dir_path,
"img_id": index,
"focal": focal,
"c": torch.tensor([cx, cy], dtype=torch.float32),
"images": all_imgs,
"masks": all_masks,
"bbox": all_bboxes,
"poses": all_poses,
}
return result
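A quick usage sketch (the path is a placeholder): each dataset item bundles all views of one object, so a DataLoader batch is a batch of objects.

import torch

dset = SRNDataset("<srn data dir>/cars", stage="train")
loader = torch.utils.data.DataLoader(dset, batch_size=4, shuffle=True)
batch = next(iter(loader))
print(batch["images"].shape)  # (SB, NV, 3, 128, 128), normalized to roughly [-1, 1]
print(batch["poses"].shape)   # (SB, NV, 4, 4) camera-to-world poses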
net
net = make_model(conf["model"]).to(device=device)
net.stop_encoder_grad = args.freeze_enc
if args.freeze_enc:
print("Encoder frozen")
net.encoder.eval()
def make_model(conf, *args, **kwargs):
""" 允许更多模型类型的占位符 """
model_type = conf.get_string("type", "pixelnerf") # single
if model_type == "pixelnerf":
net = PixelNeRFNet(conf, *args, **kwargs)
else:
raise NotImplementedError("Unsupported model type", model_type)
return net
class PixelNeRFNet
class PixelNeRFNet(torch.nn.Module):
def __init__(self, conf, stop_encoder_grad=False):
"""
:param conf PyHocon config subtree 'model'
"""
super().__init__()
self.encoder = make_encoder(conf["encoder"])
self.use_encoder = conf.get_bool("use_encoder", True) # Image features?
self.use_xyz = conf.get_bool("use_xyz", False)
assert self.use_encoder or self.use_xyz # Must use some feature..
        # Whether to shift z to align in the canonical frame,
        # so that all objects, regardless of camera-to-center distance, are centered at z=0.
        # Only makes sense in a ShapeNet-type setting.
self.normalize_z = conf.get_bool("normalize_z", True)
self.stop_encoder_grad = (
stop_encoder_grad # Stop ConvNet gradient (freeze weights)
)
self.use_code = conf.get_bool("use_code", False) # Positional encoding
self.use_code_viewdirs = conf.get_bool(
"use_code_viewdirs", True
) # Positional encoding applies to viewdirs
# Enable view directions
self.use_viewdirs = conf.get_bool("use_viewdirs", False)
# Global image features?
self.use_global_encoder = conf.get_bool("use_global_encoder", False)
d_latent = self.encoder.latent_size if self.use_encoder else 0
d_in = 3 if self.use_xyz else 1
if self.use_viewdirs and self.use_code_viewdirs:
# Apply positional encoding to viewdirs
d_in += 3
if self.use_code and d_in > 0:
# Positional encoding for x,y,z OR view z
self.code = PositionalEncoding.from_conf(conf["code"], d_in=d_in)
d_in = self.code.d_out
if self.use_viewdirs and not self.use_code_viewdirs:
# Don't apply positional encoding to viewdirs (concat after encoded)
d_in += 3
if self.use_global_encoder:
# Global image feature
self.global_encoder = ImageEncoder.from_conf(conf["global_encoder"])
self.global_latent_size = self.global_encoder.latent_size
d_latent += self.global_latent_size
d_out = 4
self.latent_size = self.encoder.latent_size
self.mlp_coarse = make_mlp(conf["mlp_coarse"], d_in, d_latent, d_out=d_out)
self.mlp_fine = make_mlp(
conf["mlp_fine"], d_in, d_latent, d_out=d_out, allow_empty=True
)
# Note: this is world -> camera, and the bottom row is omitted
self.register_buffer("poses", torch.empty(1, 3, 4), persistent=False)
self.register_buffer("image_shape", torch.empty(2), persistent=False)
self.d_in = d_in
self.d_out = d_out
self.d_latent = d_latent
self.register_buffer("focal", torch.empty(1, 2), persistent=False)
# Principal point
self.register_buffer("c", torch.empty(1, 2), persistent=False)
self.num_objs = 0
self.num_views_per_obj = 1
def encode(self, images, poses, focal, z_bounds=None, c=None):
"""
:param images (NS, 3, H, W)
NS is number of input (aka source or reference) views
:param poses (NS, 4, 4)
:param focal focal length () or (2) or (NS) or (NS, 2) [fx, fy]
:param z_bounds ignored argument (used in the past)
:param c principal point None or () or (2) or (NS) or (NS, 2) [cx, cy],
default is center of image
"""
self.num_objs = images.size(0)
if len(images.shape) == 5:
assert len(poses.shape) == 4
assert poses.size(1) == images.size(
1
) # Be consistent with NS = num input views
self.num_views_per_obj = images.size(1)
images = images.reshape(-1, *images.shape[2:])
poses = poses.reshape(-1, 4, 4)
else:
self.num_views_per_obj = 1
self.encoder(images)
        rot = poses[:, :3, :3].transpose(1, 2)  # (B, 3, 3) R^T
        trans = -torch.bmm(rot, poses[:, :3, 3:])  # (B, 3, 1) -R^T t
        self.poses = torch.cat((rot, trans), dim=-1)  # (B, 3, 4) world -> camera (see the check after encode)
self.image_shape[0] = images.shape[-1]
self.image_shape[1] = images.shape[-2]
# Handle various focal length/principal point formats
if len(focal.shape) == 0:
# Scalar: fx = fy = value for all views
focal = focal[None, None].repeat((1, 2))
elif len(focal.shape) == 1:
# Vector f: fx = fy = f_i *for view i*
# Length should match NS (or 1 for broadcast)
focal = focal.unsqueeze(-1).repeat((1, 2))
else:
focal = focal.clone()
self.focal = focal.float()
self.focal[..., 1] *= -1.0
if c is None:
# Default principal point is center of image
c = (self.image_shape * 0.5).unsqueeze(0)
elif len(c.shape) == 0:
# Scalar: cx = cy = value for all views
c = c[None, None].repeat((1, 2))
elif len(c.shape) == 1:
# Vector c: cx = cy = c_i *for view i*
c = c.unsqueeze(-1).repeat((1, 2))
self.c = c
if self.use_global_encoder:
self.global_encoder(images)
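The rot/trans lines at the top of encode invert each camera-to-world pose [R|t] into world-to-camera [R^T | -R^T t]. A quick standalone check of that identity:

import torch

pose = torch.eye(4)
pose[:3, :3] = torch.linalg.qr(torch.randn(3, 3))[0]  # random orthonormal R
pose[:3, 3:] = torch.randn(3, 1)                      # random translation t

rot = pose[:3, :3].t()                 # R^T
trans = -rot @ pose[:3, 3:]            # -R^T t
w2c = torch.cat((rot, trans), dim=-1)  # (3, 4), one row of self.poses
# Composing with the original pose recovers [I | 0]:
assert torch.allclose(w2c @ pose, torch.eye(4)[:3], atol=1e-5)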
def forward(self, xyz, coarse=True, viewdirs=None, far=False):
"""
Predict (r, g, b, sigma) at world space points xyz.
Please call encode first!
:param xyz (SB, B, 3)
SB is batch of objects
B is batch of points (in rays)
NS is number of input views
:return (SB, B, 4) r g b sigma
"""
with profiler.record_function("model_inference"):
SB, B, _ = xyz.shape
NS = self.num_views_per_obj
# Transform query points into the camera spaces of the input views
xyz = repeat_interleave(xyz, NS) # (SB*NS, B, 3)
xyz_rot = torch.matmul(self.poses[:, None, :3, :3], xyz.unsqueeze(-1))[
..., 0
]
xyz = xyz_rot + self.poses[:, None, :3, 3]
if self.d_in > 0:
# * Encode the xyz coordinates
if self.use_xyz:
if self.normalize_z:
z_feature = xyz_rot.reshape(-1, 3) # (SB*B, 3)
else:
z_feature = xyz.reshape(-1, 3) # (SB*B, 3)
else:
if self.normalize_z:
z_feature = -xyz_rot[..., 2].reshape(-1, 1) # (SB*B, 1)
else:
z_feature = -xyz[..., 2].reshape(-1, 1) # (SB*B, 1)
if self.use_code and not self.use_code_viewdirs:
# Positional encoding (no viewdirs)
z_feature = self.code(z_feature)
if self.use_viewdirs:
# * Encode the view directions
assert viewdirs is not None
# Viewdirs to input view space
viewdirs = viewdirs.reshape(SB, B, 3, 1)
viewdirs = repeat_interleave(viewdirs, NS) # (SB*NS, B, 3, 1)
viewdirs = torch.matmul(
self.poses[:, None, :3, :3], viewdirs
) # (SB*NS, B, 3, 1)
viewdirs = viewdirs.reshape(-1, 3) # (SB*B, 3)
z_feature = torch.cat(
(z_feature, viewdirs), dim=1
) # (SB*B, 4 or 6)
if self.use_code and self.use_code_viewdirs:
# Positional encoding (with viewdirs)
z_feature = self.code(z_feature)
mlp_input = z_feature
if self.use_encoder:
# Grab encoder's latent code.
                uv = -xyz[:, :, :2] / xyz[:, :, 2:]  # (SB*NS, B, 2) pinhole projection (sketch after forward)
uv *= repeat_interleave(
self.focal.unsqueeze(1), NS if self.focal.shape[0] > 1 else 1
)
uv += repeat_interleave(
self.c.unsqueeze(1), NS if self.c.shape[0] > 1 else 1
) # (SB*NS, B, 2)
latent = self.encoder.index(
uv, None, self.image_shape
) # (SB * NS, latent, B)
if self.stop_encoder_grad:
latent = latent.detach()
latent = latent.transpose(1, 2).reshape(
-1, self.latent_size
) # (SB * NS * B, latent)
if self.d_in == 0:
# z_feature not needed
mlp_input = latent
else:
mlp_input = torch.cat((latent, z_feature), dim=-1)
if self.use_global_encoder:
# Concat global latent code if enabled
global_latent = self.global_encoder.latent
assert mlp_input.shape[0] % global_latent.shape[0] == 0
num_repeats = mlp_input.shape[0] // global_latent.shape[0]
global_latent = repeat_interleave(global_latent, num_repeats)
mlp_input = torch.cat((global_latent, mlp_input), dim=-1)
# Camera frustum culling stuff, currently disabled
combine_index = None
dim_size = None
# Run main NeRF network
if coarse or self.mlp_fine is None:
mlp_output = self.mlp_coarse(
mlp_input,
combine_inner_dims=(self.num_views_per_obj, B),
combine_index=combine_index,
dim_size=dim_size,
)
else:
mlp_output = self.mlp_fine(
mlp_input,
combine_inner_dims=(self.num_views_per_obj, B),
combine_index=combine_index,
dim_size=dim_size,
)
# Interpret the output
mlp_output = mlp_output.reshape(-1, B, self.d_out)
rgb = mlp_output[..., :3]
sigma = mlp_output[..., 3:4]
output_list = [torch.sigmoid(rgb), torch.relu(sigma)]
output = torch.cat(output_list, dim=-1)
output = output.reshape(SB, B, -1)
return output
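The uv computation in forward is a pinhole projection into each source view (the y sign is absorbed by the flipped fy set in encode). A standalone sketch under those assumptions; project_points and its inputs are illustrative, not part of the repository:

import torch

def project_points(xyz_cam, focal, c):
    # Pinhole-project camera-space points (B, 3) to pixel coordinates (B, 2),
    # mirroring uv = -xyz[..., :2] / xyz[..., 2:] * focal + c in forward().
    uv = -xyz_cam[:, :2] / xyz_cam[:, 2:]  # perspective divide
    return uv * focal + c

pts = torch.tensor([[0.1, -0.2, -1.5]])  # a point 1.5 units in front of the camera
print(project_points(pts, torch.tensor([120.0, -120.0]), torch.tensor([64.0, 64.0])))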
def load_weights(self, args, opt_init=False, strict=True, device=None):
"""
Helper for loading weights according to argparse arguments.
You can put a checkpoint at checkpoints/<exp>/pixel_nerf_init to use as initialization.
:param opt_init if true, loads from init checkpoint instead of usual even when resuming
"""
# TODO: make backups
if opt_init and not args.resume:
return
ckpt_name = (
"pixel_nerf_init" if opt_init or not args.resume else "pixel_nerf_latest"
)
model_path = "%s/%s/%s" % (args.checkpoints_path, args.name, ckpt_name)
if device is None:
device = self.poses.device
if os.path.exists(model_path):
print("Load", model_path)
self.load_state_dict(
torch.load(model_path, map_location=device), strict=strict
)
elif not opt_init:
warnings.warn(
(
"WARNING: {} does not exist, not loaded!! Model will be re-initialized.\n"
+ "If you are trying to load a pretrained model, STOP since it's not in the right place. "
+ "If training, unless you are startin a new experiment, please remember to pass --resume."
).format(model_path)
)
return self
def save_weights(self, args, opt_init=False):
"""
Helper for saving weights according to argparse arguments
:param opt_init if true, saves from init checkpoint instead of usual
"""
        from shutil import copyfile

        # osp = os.path ("import os.path as osp" at the top of src/model/models.py)
ckpt_name = "pixel_nerf_init" if opt_init else "pixel_nerf_latest"
backup_name = "pixel_nerf_init_backup" if opt_init else "pixel_nerf_backup"
ckpt_path = osp.join(args.checkpoints_path, args.name, ckpt_name)
ckpt_backup_path = osp.join(args.checkpoints_path, args.name, backup_name)
if osp.exists(ckpt_path):
copyfile(ckpt_path, ckpt_backup_path)
torch.save(self.state_dict(), ckpt_path)
return self
renderer
renderer = NeRFRenderer.from_conf(conf["renderer"], lindisp=dset.lindisp,).to(
device=device
)
# Parallelize
render_par = renderer.bind_parallel(net, args.gpu_id).eval()
nviews = list(map(int, args.nviews.split()))
class NeRFRenderer
class NeRFRenderer(torch.nn.Module):
"""
NeRF differentiable renderer
:param n_coarse number of coarse (binned uniform) samples
:param n_fine number of fine (importance) samples
:param n_fine_depth number of expected depth samples
:param noise_std noise to add to sigma. We do not use it
:param depth_std noise for depth samples
:param eval_batch_size ray batch size for evaluation
:param white_bkgd if true, background color is white; else black
:param lindisp if to use samples linear in disparity instead of distance
:param sched ray sampling schedule. list containing 3 lists of equal length.
sched[0] is list of iteration numbers,
sched[1] is list of coarse sample numbers,
sched[2] is list of fine sample numbers
"""
def __init__(
self,
n_coarse=128,
n_fine=0,
n_fine_depth=0,
noise_std=0.0,
depth_std=0.01,
eval_batch_size=100000,
white_bkgd=False,
lindisp=False,
sched=None, # ray sampling schedule for coarse and fine rays
):
super().__init__()
self.n_coarse = n_coarse
self.n_fine = n_fine
self.n_fine_depth = n_fine_depth
self.noise_std = noise_std
self.depth_std = depth_std
self.eval_batch_size = eval_batch_size
self.white_bkgd = white_bkgd
self.lindisp = lindisp
if lindisp:
print("Using linear displacement rays")
self.using_fine = n_fine > 0
self.sched = sched
if sched is not None and len(sched) == 0:
self.sched = None
self.register_buffer(
"iter_idx", torch.tensor(0, dtype=torch.long), persistent=True
)
self.register_buffer(
"last_sched", torch.tensor(0, dtype=torch.long), persistent=True
)
def sample_coarse(self, rays):
"""
Stratified sampling. Note this is different from original NeRF slightly.
:param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
:return (B, Kc)
"""
device = rays.device
near, far = rays[:, -2:-1], rays[:, -1:] # (B, 1)
step = 1.0 / self.n_coarse
B = rays.shape[0]
z_steps = torch.linspace(0, 1 - step, self.n_coarse, device=device) # (Kc)
z_steps = z_steps.unsqueeze(0).repeat(B, 1) # (B, Kc)
z_steps += torch.rand_like(z_steps) * step
if not self.lindisp: # Use linear sampling in depth space
return near * (1 - z_steps) + far * z_steps # (B, Kf)
else: # Use linear sampling in disparity space
return 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps) # (B, Kf)
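A toy illustration of the stratified scheme above: one jittered sample per bin between near=1 and far=2.

import torch

torch.manual_seed(0)
K = 4
step = 1.0 / K
z_steps = torch.linspace(0, 1 - step, K) + torch.rand(K) * step  # jitter within each bin
print(1.0 * (1 - z_steps) + 2.0 * z_steps)  # monotone; one sample per quarter of [1, 2]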
def sample_fine(self, rays, weights):
"""
Weighted stratified (importance) sample
:param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
:param weights (B, Kc)
:return (B, Kf-Kfd)
"""
device = rays.device
B = rays.shape[0]
weights = weights.detach() + 1e-5 # Prevent division by zero
pdf = weights / torch.sum(weights, -1, keepdim=True) # (B, Kc)
cdf = torch.cumsum(pdf, -1) # (B, Kc)
cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1) # (B, Kc+1)
u = torch.rand(
B, self.n_fine - self.n_fine_depth, dtype=torch.float32, device=device
) # (B, Kf)
inds = torch.searchsorted(cdf, u, right=True).float() - 1.0 # (B, Kf)
inds = torch.clamp_min(inds, 0.0)
z_steps = (inds + torch.rand_like(inds)) / self.n_coarse # (B, Kf)
near, far = rays[:, -2:-1], rays[:, -1:] # (B, 1)
if not self.lindisp: # Use linear sampling in depth space
z_samp = near * (1 - z_steps) + far * z_steps # (B, Kf)
else: # Use linear sampling in disparity space
z_samp = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps) # (B, Kf)
return z_samp
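A toy inverse-CDF draw showing the intent of sample_fine: fine samples concentrate in the bin that carries most of the coarse weight.

import torch

torch.manual_seed(0)
weights = torch.tensor([[0.1, 0.7, 0.2]]) + 1e-5
pdf = weights / weights.sum(-1, keepdim=True)
cdf = torch.cat([torch.zeros(1, 1), torch.cumsum(pdf, -1)], -1)  # (1, Kc+1)
u = torch.rand(1, 8)
inds = (torch.searchsorted(cdf, u, right=True).float() - 1.0).clamp(min=0.0)
print(inds)  # mostly 1.0, i.e. the heavy middle bin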
def sample_fine_depth(self, rays, depth):
"""
Sample around specified depth
:param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
:param depth (B)
:return (B, Kfd)
"""
z_samp = depth.unsqueeze(1).repeat((1, self.n_fine_depth))
z_samp += torch.randn_like(z_samp) * self.depth_std
# Clamp does not support tensor bounds
z_samp = torch.max(torch.min(z_samp, rays[:, -1:]), rays[:, -2:-1])
return z_samp
def composite(self, model, rays, z_samp, coarse=True, sb=0):
"""
Render RGB and depth for each ray using NeRF alpha-compositing formula,
given sampled positions along each ray (see sample_*)
:param model should return (B, (r, g, b, sigma)) when called with (B, (x, y, z))
should also support 'coarse' boolean argument
:param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
:param z_samp z positions sampled for each ray (B, K)
:param coarse whether to evaluate using coarse NeRF
:param sb super-batch dimension; 0 = disable
:return weights (B, K), rgb (B, 3), depth (B)
"""
with profiler.record_function("renderer_composite"):
B, K = z_samp.shape
deltas = z_samp[:, 1:] - z_samp[:, :-1] # (B, K-1)
# if far:
# delta_inf = 1e10 * torch.ones_like(deltas[:, :1]) # infty (B, 1)
delta_inf = rays[:, -1:] - z_samp[:, -1:]
deltas = torch.cat([deltas, delta_inf], -1) # (B, K)
# (B, K, 3)
points = rays[:, None, :3] + z_samp.unsqueeze(2) * rays[:, None, 3:6]
points = points.reshape(-1, 3) # (B*K, 3)
use_viewdirs = hasattr(model, "use_viewdirs") and model.use_viewdirs
val_all = []
if sb > 0:
points = points.reshape(
sb, -1, 3
) # (SB, B'*K, 3) B' is real ray batch size
eval_batch_size = (self.eval_batch_size - 1) // sb + 1
eval_batch_dim = 1
else:
eval_batch_size = self.eval_batch_size
eval_batch_dim = 0
split_points = torch.split(points, eval_batch_size, dim=eval_batch_dim)
if use_viewdirs:
dim1 = K
viewdirs = rays[:, None, 3:6].expand(-1, dim1, -1) # (B, K, 3)
if sb > 0:
viewdirs = viewdirs.reshape(sb, -1, 3) # (SB, B'*K, 3)
else:
viewdirs = viewdirs.reshape(-1, 3) # (B*K, 3)
split_viewdirs = torch.split(
viewdirs, eval_batch_size, dim=eval_batch_dim
)
for pnts, dirs in zip(split_points, split_viewdirs):
val_all.append(model(pnts, coarse=coarse, viewdirs=dirs))
else:
for pnts in split_points:
val_all.append(model(pnts, coarse=coarse))
points = None
viewdirs = None
# (B*K, 4) OR (SB, B'*K, 4)
out = torch.cat(val_all, dim=eval_batch_dim)
out = out.reshape(B, K, -1) # (B, K, 4 or 5)
rgbs = out[..., :3] # (B, K, 3)
sigmas = out[..., 3] # (B, K)
if self.training and self.noise_std > 0.0:
sigmas = sigmas + torch.randn_like(sigmas) * self.noise_std
alphas = 1 - torch.exp(-deltas * torch.relu(sigmas)) # (B, K)
deltas = None
sigmas = None
alphas_shifted = torch.cat(
[torch.ones_like(alphas[:, :1]), 1 - alphas + 1e-10], -1
) # (B, K+1) = [1, a1, a2, ...]
            T = torch.cumprod(alphas_shifted, -1)  # (B, K+1)
weights = alphas * T[:, :-1] # (B, K)
alphas = None
alphas_shifted = None
rgb_final = torch.sum(weights.unsqueeze(-1) * rgbs, -2) # (B, 3)
depth_final = torch.sum(weights * z_samp, -1) # (B)
if self.white_bkgd:
# White background
pix_alpha = weights.sum(dim=1) # (B), pixel alpha
rgb_final = rgb_final + 1 - pix_alpha.unsqueeze(-1) # (B, 3)
return (
weights,
rgb_final,
depth_final,
)
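The weights computed above follow the standard NeRF quadrature: alpha_i = 1 - exp(-sigma_i * delta_i), T_i = prod_{j<i} (1 - alpha_j), w_i = alpha_i * T_i. A toy check of that recursion:

import torch

sigmas = torch.tensor([0.0, 5.0, 50.0])  # empty, semi-dense, dense samples
deltas = torch.tensor([0.1, 0.1, 0.1])
alphas = 1 - torch.exp(-deltas * sigmas)
T = torch.cumprod(torch.cat([torch.ones(1), 1 - alphas]), 0)
weights = alphas * T[:-1]
print(weights, weights.sum())  # the dense sample dominates; the sum stays <= 1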
def forward(
self, model, rays, want_weights=False,
):
"""
:model nerf model, should return (SB, B, (r, g, b, sigma))
when called with (SB, B, (x, y, z)), for multi-object:
SB = 'super-batch' = size of object batch,
B = size of per-object ray batch.
Should also support 'coarse' boolean argument for coarse NeRF.
:param rays ray spec [origins (3), directions (3), near (1), far (1)] (SB, B, 8)
:param want_weights if true, returns compositing weights (SB, B, K)
:return render dict
"""
with profiler.record_function("renderer_forward"):
if self.sched is not None and self.last_sched.item() > 0:
self.n_coarse = self.sched[1][self.last_sched.item() - 1]
self.n_fine = self.sched[2][self.last_sched.item() - 1]
assert len(rays.shape) == 3
superbatch_size = rays.shape[0]
rays = rays.reshape(-1, 8) # (SB * B, 8)
z_coarse = self.sample_coarse(rays) # (B, Kc)
coarse_composite = self.composite(
model, rays, z_coarse, coarse=True, sb=superbatch_size,
)
outputs = DotMap(
coarse=self._format_outputs(
coarse_composite, superbatch_size, want_weights=want_weights,
),
)
if self.using_fine:
all_samps = [z_coarse]
if self.n_fine - self.n_fine_depth > 0:
all_samps.append(
self.sample_fine(rays, coarse_composite[0].detach())
) # (B, Kf - Kfd)
if self.n_fine_depth > 0:
all_samps.append(
self.sample_fine_depth(rays, coarse_composite[2])
) # (B, Kfd)
z_combine = torch.cat(all_samps, dim=-1) # (B, Kc + Kf)
z_combine_sorted, argsort = torch.sort(z_combine, dim=-1)
fine_composite = self.composite(
model, rays, z_combine_sorted, coarse=False, sb=superbatch_size,
)
outputs.fine = self._format_outputs(
fine_composite, superbatch_size, want_weights=want_weights,
)
return outputs
def _format_outputs(
self, rendered_outputs, superbatch_size, want_weights=False,
):
weights, rgb, depth = rendered_outputs
if superbatch_size > 0:
rgb = rgb.reshape(superbatch_size, -1, 3)
depth = depth.reshape(superbatch_size, -1)
weights = weights.reshape(superbatch_size, -1, weights.shape[-1])
ret_dict = DotMap(rgb=rgb, depth=depth)
if want_weights:
ret_dict.weights = weights
return ret_dict
def sched_step(self, steps=1):
"""
Called each training iteration to update sample numbers
according to schedule
"""
if self.sched is None:
return
self.iter_idx += steps
while (
self.last_sched.item() < len(self.sched[0])
and self.iter_idx.item() >= self.sched[0][self.last_sched.item()]
):
self.n_coarse = self.sched[1][self.last_sched.item()]
self.n_fine = self.sched[2][self.last_sched.item()]
print(
"INFO: NeRF sampling resolution changed on schedule ==> c",
self.n_coarse,
"f",
self.n_fine,
)
self.last_sched += 1
@classmethod
def from_conf(cls, conf, white_bkgd=False, lindisp=False, eval_batch_size=100000):
return cls(
conf.get_int("n_coarse", 128),
conf.get_int("n_fine", 0),
n_fine_depth=conf.get_int("n_fine_depth", 0),
noise_std=conf.get_float("noise_std", 0.0),
depth_std=conf.get_float("depth_std", 0.01),
white_bkgd=conf.get_float("white_bkgd", white_bkgd),
lindisp=lindisp,
eval_batch_size=conf.get_int("eval_batch_size", eval_batch_size),
sched=conf.get_list("sched", None),
)
def bind_parallel(self, net, gpus=None, simple_output=False):
"""
Returns a wrapper module compatible with DataParallel.
Specifically, it renders rays with this renderer
but always using the given network instance.
Specify a list of GPU ids in 'gpus' to apply DataParallel automatically.
:param net A PixelNeRF network
    :param gpus list of GPU ids to parallelize to. If length is 1,
    does not parallelize
    :param simple_output only returns rendered (rgb, depth) instead of the
    full render output map. Saves data transfer cost.
:return torch module
"""
wrapped = _RenderWrapper(net, self, simple_output=simple_output)
if gpus is not None and len(gpus) > 1:
print("Using multi-GPU", gpus)
wrapped = torch.nn.DataParallel(wrapped, gpus, dim=1)
return wrapped
train
trainer = PixelNeRFTrainer()
trainer.start()
class PixelNeRFTrainer
class PixelNeRFTrainer(trainlib.Trainer):
def __init__(self):
super().__init__(net, dset, val_dset, args, conf["train"], device=device)
self.renderer_state_path = "%s/%s/_renderer" % (
self.args.checkpoints_path,
self.args.name,
)
self.lambda_coarse = conf.get_float("loss.lambda_coarse")
self.lambda_fine = conf.get_float("loss.lambda_fine", 1.0)
print(
"lambda coarse {} and fine {}".format(self.lambda_coarse, self.lambda_fine)
)
self.rgb_coarse_crit = loss.get_rgb_loss(conf["loss.rgb"], True)
fine_loss_conf = conf["loss.rgb"]
if "rgb_fine" in conf["loss"]:
print("using fine loss")
fine_loss_conf = conf["loss.rgb_fine"]
self.rgb_fine_crit = loss.get_rgb_loss(fine_loss_conf, False)
if args.resume:
if os.path.exists(self.renderer_state_path):
renderer.load_state_dict(
torch.load(self.renderer_state_path, map_location=device)
)
self.z_near = dset.z_near
self.z_far = dset.z_far
self.use_bbox = args.no_bbox_step > 0
def post_batch(self, epoch, batch):
renderer.sched_step(args.batch_size)
def extra_save_state(self):
torch.save(renderer.state_dict(), self.renderer_state_path)
def calc_losses(self, data, is_train=True, global_step=0):
if "images" not in data:
return {}
all_images = data["images"].to(device=device) # (SB, NV, 3, H, W)
SB, NV, _, H, W = all_images.shape
all_poses = data["poses"].to(device=device) # (SB, NV, 4, 4)
all_bboxes = data.get("bbox") # (SB, NV, 4) cmin rmin cmax rmax
all_focals = data["focal"] # (SB)
all_c = data.get("c") # (SB)
if self.use_bbox and global_step >= args.no_bbox_step:
self.use_bbox = False
print(">>> Stopped using bbox sampling @ iter", global_step)
if not is_train or not self.use_bbox:
all_bboxes = None
all_rgb_gt = []
all_rays = []
curr_nviews = nviews[torch.randint(0, len(nviews), ()).item()]
if curr_nviews == 1:
image_ord = torch.randint(0, NV, (SB, 1))
else:
image_ord = torch.empty((SB, curr_nviews), dtype=torch.long)
for obj_idx in range(SB):
if all_bboxes is not None:
bboxes = all_bboxes[obj_idx]
images = all_images[obj_idx] # (NV, 3, H, W)
poses = all_poses[obj_idx] # (NV, 4, 4)
focal = all_focals[obj_idx]
c = None
if "c" in data:
c = data["c"][obj_idx]
if curr_nviews > 1:
# Somewhat inefficient, don't know better way
image_ord[obj_idx] = torch.from_numpy(
np.random.choice(NV, curr_nviews, replace=False)
)
images_0to1 = images * 0.5 + 0.5
cam_rays = util.gen_rays(
poses, W, H, focal, self.z_near, self.z_far, c=c
) # (NV, H, W, 8)
rgb_gt_all = images_0to1
            rgb_gt_all = (
                rgb_gt_all.permute(0, 2, 3, 1).contiguous().reshape(-1, 3)
            )  # (NV*H*W, 3)
if all_bboxes is not None:
pix = util.bbox_sample(bboxes, args.ray_batch_size)
pix_inds = pix[..., 0] * H * W + pix[..., 1] * W + pix[..., 2]
else:
pix_inds = torch.randint(0, NV * H * W, (args.ray_batch_size,))
rgb_gt = rgb_gt_all[pix_inds] # (ray_batch_size, 3)
rays = cam_rays.view(-1, cam_rays.shape[-1])[pix_inds].to(
device=device
) # (ray_batch_size, 8)
all_rgb_gt.append(rgb_gt)
all_rays.append(rays)
all_rgb_gt = torch.stack(all_rgb_gt) # (SB, ray_batch_size, 3)
all_rays = torch.stack(all_rays) # (SB, ray_batch_size, 8)
image_ord = image_ord.to(device)
src_images = util.batched_index_select_nd(
all_images, image_ord
) # (SB, NS, 3, H, W)
src_poses = util.batched_index_select_nd(all_poses, image_ord) # (SB, NS, 4, 4)
all_bboxes = all_poses = all_images = None
net.encode(
src_images,
src_poses,
all_focals.to(device=device),
c=all_c.to(device=device) if all_c is not None else None,
)
render_dict = DotMap(render_par(all_rays, want_weights=True,))
coarse = render_dict.coarse
fine = render_dict.fine
using_fine = len(fine) > 0
loss_dict = {}
rgb_loss = self.rgb_coarse_crit(coarse.rgb, all_rgb_gt)
loss_dict["rc"] = rgb_loss.item() * self.lambda_coarse
if using_fine:
fine_loss = self.rgb_fine_crit(fine.rgb, all_rgb_gt)
rgb_loss = rgb_loss * self.lambda_coarse + fine_loss * self.lambda_fine
loss_dict["rf"] = fine_loss.item() * self.lambda_fine
loss = rgb_loss
if is_train:
loss.backward()
loss_dict["t"] = loss.item()
return loss_dict
def train_step(self, data, global_step):
return self.calc_losses(data, is_train=True, global_step=global_step)
def eval_step(self, data, global_step):
renderer.eval()
losses = self.calc_losses(data, is_train=False, global_step=global_step)
renderer.train()
return losses
def vis_step(self, data, global_step, idx=None):
if "images" not in data:
return {}
if idx is None:
batch_idx = np.random.randint(0, data["images"].shape[0])
else:
print(idx)
batch_idx = idx
images = data["images"][batch_idx].to(device=device) # (NV, 3, H, W)
poses = data["poses"][batch_idx].to(device=device) # (NV, 4, 4)
focal = data["focal"][batch_idx : batch_idx + 1] # (1)
c = data.get("c")
if c is not None:
c = c[batch_idx : batch_idx + 1] # (1)
NV, _, H, W = images.shape
cam_rays = util.gen_rays(
poses, W, H, focal, self.z_near, self.z_far, c=c
) # (NV, H, W, 8)
images_0to1 = images * 0.5 + 0.5 # (NV, 3, H, W)
curr_nviews = nviews[torch.randint(0, len(nviews), (1,)).item()]
views_src = np.sort(np.random.choice(NV, curr_nviews, replace=False))
view_dest = np.random.randint(0, NV - curr_nviews)
for vs in range(curr_nviews):
view_dest += view_dest >= views_src[vs]
views_src = torch.from_numpy(views_src)
# set renderer net to eval mode
renderer.eval()
source_views = (
images_0to1[views_src]
.permute(0, 2, 3, 1)
.cpu()
.numpy()
.reshape(-1, H, W, 3)
)
gt = images_0to1[view_dest].permute(1, 2, 0).cpu().numpy().reshape(H, W, 3)
with torch.no_grad():
test_rays = cam_rays[view_dest] # (H, W, 8)
test_images = images[views_src] # (NS, 3, H, W)
net.encode(
test_images.unsqueeze(0),
poses[views_src].unsqueeze(0),
focal.to(device=device),
c=c.to(device=device) if c is not None else None,
)
test_rays = test_rays.reshape(1, H * W, -1)
render_dict = DotMap(render_par(test_rays, want_weights=True))
coarse = render_dict.coarse
fine = render_dict.fine
using_fine = len(fine) > 0
alpha_coarse_np = coarse.weights[0].sum(dim=-1).cpu().numpy().reshape(H, W)
rgb_coarse_np = coarse.rgb[0].cpu().numpy().reshape(H, W, 3)
depth_coarse_np = coarse.depth[0].cpu().numpy().reshape(H, W)
if using_fine:
alpha_fine_np = fine.weights[0].sum(dim=1).cpu().numpy().reshape(H, W)
depth_fine_np = fine.depth[0].cpu().numpy().reshape(H, W)
rgb_fine_np = fine.rgb[0].cpu().numpy().reshape(H, W, 3)
print("c rgb min {} max {}".format(rgb_coarse_np.min(), rgb_coarse_np.max()))
print(
"c alpha min {}, max {}".format(
alpha_coarse_np.min(), alpha_coarse_np.max()
)
)
alpha_coarse_cmap = util.cmap(alpha_coarse_np) / 255
depth_coarse_cmap = util.cmap(depth_coarse_np) / 255
vis_list = [
*source_views,
gt,
depth_coarse_cmap,
rgb_coarse_np,
alpha_coarse_cmap,
]
vis_coarse = np.hstack(vis_list)
vis = vis_coarse
if using_fine:
print("f rgb min {} max {}".format(rgb_fine_np.min(), rgb_fine_np.max()))
print(
"f alpha min {}, max {}".format(
alpha_fine_np.min(), alpha_fine_np.max()
)
)
depth_fine_cmap = util.cmap(depth_fine_np) / 255
alpha_fine_cmap = util.cmap(alpha_fine_np) / 255
vis_list = [
*source_views,
gt,
depth_fine_cmap,
rgb_fine_np,
alpha_fine_cmap,
]
vis_fine = np.hstack(vis_list)
vis = np.vstack((vis_coarse, vis_fine))
rgb_psnr = rgb_fine_np
else:
rgb_psnr = rgb_coarse_np
psnr = util.psnr(rgb_psnr, gt)
vals = {"psnr": psnr}
print("psnr", psnr)
# set the renderer network back to train mode
renderer.train()
return vis, vals