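This post adapts the hand-written NeRF code into a simplified, bare-bones implementation of TensoRF.

The dataset is included in https://github.com/shouxieai/nerf.git. Download the code from that repository, then replace the contents of its train-nerf.py with the code below; the result is a minimal TensoRF implementation.

How to run the code: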

# Getting Started
1. Train: `python train-nerf.py --half-resolution`

# Run Demo
1. Run `python train-nerf.py --make-video360`
    - Produces a video with a 360-degree rendering, using the checkpoint given by `--ckpt` (default `ckpt/2000.pth`)

# Demo
![](rotate360/008.png)

# Reference
1. https://github.com/bmild/nerf
2. https://github.com/yenchenlin/nerf-pytorch

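Once the VM (vector-matrix) components are defined, they provide the volume density and the appearance (color) features; these are combined with the view direction and passed to a small neural network that regresses the color.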
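
As a rough illustration of what the VM components compute, the sketch below evaluates a vector-matrix factorisation at a grid index and checks it against the dense tensor it represents. It is a simplified stand-in, not the code from this post: it uses NumPy, integer grid indices instead of the interpolated `F.grid_sample` lookups, and made-up sizes (`R`, `X`, `Y`, `Z`) and names (`planes`, `lines`, `vm_value`).

```python
import numpy as np

rng = np.random.default_rng(0)
R, X, Y, Z = 4, 8, 8, 8   # hypothetical sizes: 4 components per mode, 8^3 grid

# One plane (matrix) and one line (vector) factor per mode, mirroring
# matMode = [[0,1], [0,2], [1,2]] and vecMode = [2, 1, 0] in the listing.
planes = [rng.normal(size=(R, X, Y)),
          rng.normal(size=(R, X, Z)),
          rng.normal(size=(R, Y, Z))]
lines = [rng.normal(size=(R, Z)),
         rng.normal(size=(R, Y)),
         rng.normal(size=(R, X))]

def vm_value(i, j, k):
    """Feature at grid index (i, j, k): sum over modes and components of plane * line."""
    return (np.sum(planes[0][:, i, j] * lines[0][:, k])
            + np.sum(planes[1][:, i, k] * lines[1][:, j])
            + np.sum(planes[2][:, j, k] * lines[2][:, i]))

# Dense reference tensor, built here only to check the factorised lookup.
dense = (np.einsum("rxy,rz->xyz", planes[0], lines[0])
         + np.einsum("rxz,ry->xyz", planes[1], lines[1])
         + np.einsum("ryz,rx->xyz", planes[2], lines[2]))

assert np.isclose(vm_value(1, 2, 3), dense[1, 2, 3])
```

The point of the factorisation is storage: with the sizes used in the listing (16 density components per mode on a 128x128x128 grid), the plane and line factors hold 792,576 values, compared with 2,097,152 for a dense single-channel grid.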

The code follows; use it to replace the code downloaded from GitHub. There is only one Python file anyway.

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import torch.nn.functional as F
import cv2
import os
import json
import argparse
import imageio

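# Key concepts
# 1. Positional encoding
#    - The input x, y, z coordinates are continuous, so nearby values are hard to tell apart; they are therefore encoded with Fourier features (ff)
#    - The encoding stacks sin and cos at different frequencies, which makes continuous values sufficiently distinguishable
# 2. View dependence
#    - The input is not only the x, y, z of a point sampled on the ray but also the viewing direction, i.e. the 5D input x, y, z, theta, phi, which adds the view the ray belongs to
# 3. Hierarchical sampling
#    - Rendering is split into two levels. The first-level model samples uniformly, and in practice many samples are wasted (regions that
#       contribute nothing to the color take up too large a share); from the model's point of view, those points have zero gradient and do not help training
#    - Hence two models, model and fine: model uses uniform sampling, and its inference gives a distribution of weights; resampling from that
#       distribution concentrates the samples in the more important regions, so most points used in training are useful. model is the first-level pass and fine runs on the resampled points
#
# x. Extra: understanding ray directions and origins requires basic knowledge of 3D transforms; the first 5 lectures of GAMES101 are recommended
#    PSNR is the peak signal-to-noise ratio and measures how faithful the reconstruction is
# With these three components the result is very realistic, although some details still fall short. Training time also matters a great deal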

class BlenderProvider:
    def __init__(self, root, transforms_file, half_resolution=True):

        self.meta            = json.load(open(os.path.join(root, transforms_file), "r"))
        self.root            = root
        self.frames          = self.meta["frames"]
        self.images          = []
        self.poses           = []
        self.camera_angle_x  = self.meta["camera_angle_x"]
        
        for frame in self.frames:
            image_file = os.path.join(self.root, frame["file_path"] + ".png")
            image      = imageio.imread(image_file)

            if half_resolution:
                image  = cv2.resize(image, dsize=None, fx=0.5, fy=0.5, interpolation=cv2.INTER_AREA)

            self.images.append(image)
            self.poses.append(frame["transform_matrix"])

        self.poses  = np.stack(self.poses)
        self.images = (np.stack(self.images) / 255.0).astype(np.float32)
        self.width  = self.images.shape[2]
        self.height = self.images.shape[1]
        self.focal  = 0.5 * self.width / np.tan(0.5 * self.camera_angle_x)

        alpha       = self.images[..., [3]]
        rgb         = self.images[..., :3]
        self.images = rgb * alpha + (1 - alpha)


class NeRFDataset:
    def __init__(self, provider, batch_size=1024, device="cuda"):

        self.images        = provider.images
        self.poses         = provider.poses
        self.focal         = provider.focal
        self.width         = provider.width
        self.height        = provider.height
        self.batch_size    = batch_size
        self.num_image     = len(self.images)
        self.precrop_iters = 500
        self.precrop_frac  = 0.5
        self.niter         = 0
        self.device        = device
        self.initialize()


    def initialize(self):

        warange = torch.arange(self.width,  dtype=torch.float32, device=self.device)
        harange = torch.arange(self.height, dtype=torch.float32, device=self.device)
        y, x = torch.meshgrid(harange, warange)

        self.transformed_x = (x - self.width * 0.5) / self.focal
        self.transformed_y = (y - self.height * 0.5) / self.focal

        # pre center crop
        self.precrop_index = torch.arange(self.width * self.height).view(self.height, self.width)

        dH = int(self.height // 2 * self.precrop_frac)
        dW = int(self.width  // 2 * self.precrop_frac)
        self.precrop_index = self.precrop_index[
            self.height // 2 - dH:self.height // 2 + dH, 
            self.width  // 2 - dW:self.width  // 2 + dW
        ].reshape(-1)

        # torch.tensor replaces the legacy torch.cuda.FloatTensor constructor
        poses = torch.tensor(self.poses, dtype=torch.float32, device=self.device)
        all_ray_dirs, all_ray_origins = [], []

        for i in range(len(self.images)):
            ray_dirs, ray_origins = self.make_rays(self.transformed_x, self.transformed_y, poses[i])
            all_ray_dirs.append(ray_dirs)
            all_ray_origins.append(ray_origins)

        self.all_ray_dirs    = torch.stack(all_ray_dirs, dim=0)
        self.all_ray_origins = torch.stack(all_ray_origins, dim=0)
        self.images          = torch.tensor(self.images, dtype=torch.float32, device=self.device).view(self.num_image, -1, 3)
        

    def __getitem__(self, index):
        self.niter += 1

        ray_dirs   = self.all_ray_dirs[index]
        ray_oris   = self.all_ray_origins[index]
        img_pixels = self.images[index]
        if self.niter < self.precrop_iters:
            ray_dirs   = ray_dirs[self.precrop_index]
            ray_oris   = ray_oris[self.precrop_index]
            img_pixels = img_pixels[self.precrop_index]

        nrays          = self.batch_size
        select_inds    = np.random.choice(ray_dirs.shape[0], size=[nrays], replace=False)
        ray_dirs       = ray_dirs[select_inds]
        ray_oris       = ray_oris[select_inds]
        img_pixels     = img_pixels[select_inds]

        # dirs: ray directions
        # oris: ray origins
        return ray_dirs, ray_oris, img_pixels


    def __len__(self):
        return self.num_image


    def make_rays(self, x, y, pose):

        # 100, 100, 3
        # camera coordinates: y and z are negated (image rows grow downward, the camera looks along -z)
        directions    = torch.stack([x, -y, -torch.ones_like(x)], dim=-1)
        camera_matrix = pose[:3, :3]
        
        # 10000 x 3
        ray_dirs = directions.reshape(-1, 3) @ camera_matrix.T
        ray_origin = pose[:3, 3].view(1, 3).repeat(len(ray_dirs), 1)
        return ray_dirs, ray_origin


    def get_test_item(self, index=0):

        ray_dirs   = self.all_ray_dirs[index]
        ray_oris   = self.all_ray_origins[index]
        img_pixels = self.images[index]

        for i in range(0, len(ray_dirs), self.batch_size):
            yield ray_dirs[i:i+self.batch_size], ray_oris[i:i+self.batch_size], img_pixels[i:i+self.batch_size]


    def get_rotate_360_rays(self):
        def trans_t(t):
            return np.array([
                [1,0,0,0],
                [0,1,0,0],
                [0,0,1,t],
                [0,0,0,1],
            ], dtype=np.float32)

        def rot_phi(phi):
            return np.array([
                [1,0,0,0],
                [0,np.cos(phi),-np.sin(phi),0],
                [0,np.sin(phi), np.cos(phi),0],
                [0,0,0,1],
            ], dtype=np.float32)

        def rot_theta(th) : 
            return np.array([
                [np.cos(th),0,-np.sin(th),0],
                [0,1,0,0],
                [np.sin(th),0, np.cos(th),0],
                [0,0,0,1],
            ], dtype=np.float32)

        def pose_spherical(theta, phi, radius):
            c2w = trans_t(radius)
            c2w = rot_phi(phi/180.*np.pi) @ c2w
            c2w = rot_theta(theta/180.*np.pi) @ c2w
            c2w = np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]]) @ c2w
            return c2w

        for th in np.linspace(-180., 180., 41, endpoint=False):
            pose = torch.tensor(pose_spherical(th, -30., 4.), dtype=torch.float32, device=self.device)

            def genfunc():
                ray_dirs, ray_origins = self.make_rays(self.transformed_x, self.transformed_y, pose)
                for i in range(0, len(ray_dirs), 1024):
                    yield ray_dirs[i:i+1024], ray_origins[i:i+1024]

            yield genfunc


# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False):
    # Get pdf
    device = weights.device
    weights = weights + 1e-5 # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1)  # (batch, len(bins))
    
    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., steps=N_samples, device=device)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        # create u on the same device as cdf, otherwise searchsorted fails on CUDA inputs
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device)

    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds-1), inds-1)
    above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[...,1]-cdf_g[...,0])
    denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
    t = (u-cdf_g[...,0])/denom
    samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
    return samples


def sample_rays(ray_directions, ray_origins, sample_z_vals):

    nrays = len(ray_origins)
    sample_z_vals = sample_z_vals.repeat(nrays, 1)
    rays = ray_origins[:, None, :] + ray_directions[:, None, :] * sample_z_vals[..., None]
    return rays, sample_z_vals


def sample_viewdirs(ray_directions):
    return ray_directions / torch.norm(ray_directions, dim=-1, keepdim=True)
    

def predict_to_rgb(sigma, rgb, z_vals, raydirs, white_background=False):

    device         = sigma.device
    delta_prefix   = z_vals[..., 1:] - z_vals[..., :-1]
    delta_addition = torch.full((z_vals.size(0), 1), 1e10, device=device)
    delta = torch.cat([delta_prefix, delta_addition], dim=-1)
    delta = delta * torch.norm(raydirs[..., None, :], dim=-1)

    alpha    = 1.0 - torch.exp(-sigma * delta)
    exp_term = 1.0 - alpha
    epsilon  = 1e-10
    exp_addition = torch.ones(exp_term.size(0), 1, device=device)
    exp_term = torch.cat([exp_addition, exp_term + epsilon], dim=-1)
    transmittance = torch.cumprod(exp_term, axis=-1)[..., :-1]

    weights       = alpha * transmittance
    rgb           = torch.sum(weights[..., None] * rgb, dim=-2)
    depth         = torch.sum(weights * z_vals, dim=-1)
    acc_map       = torch.sum(weights, -1)

    if white_background:
        rgb       = rgb + (1.0 - acc_map[..., None])
    return rgb, depth, acc_map, weights


def render_rays(model, fine, raydirs, rayoris, sample_z_vals, importance=0, white_background=False):

    rays, z_vals = sample_rays(raydirs, rayoris, sample_z_vals)
    view_dirs    = sample_viewdirs(raydirs)

    sigma, rgb = model(rays, view_dirs)
    sigma      = sigma.squeeze(dim=-1)
    rgb1, depth1, acc_map1, weights1 = predict_to_rgb(sigma, rgb, z_vals, raydirs, white_background)

    # resample using weights1
    z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
    z_samples  = sample_pdf(z_vals_mid, weights1[...,1:-1], importance, det=True)
    z_samples  = z_samples.detach()

    z_vals, _  = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
    rays       = rayoris[...,None,:] + raydirs[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]
    sigma, rgb = fine(rays, view_dirs)
    sigma      = sigma.squeeze(dim=-1)

    # the prediction on the resampled points is the final result; this is the hierarchical sampling step of the paper
    rgb2, depth2, acc_map2, weights2 = predict_to_rgb(sigma, rgb, z_vals, raydirs, white_background)
    return rgb1, rgb2

# head without view dependence
class NoViewDirHead(nn.Module):
    def __init__(self, ninput, noutput):
        super().__init__()

        self.head = nn.Linear(ninput, noutput)
    
    def forward(self, x, view_dirs):
        
        x = self.head(x)
        rgb   = x[..., :3].sigmoid()
        sigma = x[..., 3].relu()
        return sigma, rgb

# view-dependent head
class ViewDenepdentHead(nn.Module):
    def __init__(self, ninput, nview):
        super().__init__()

        self.feature = nn.Linear(ninput, ninput)  # feature layer on the full input
        self.view_fc = nn.Linear(ninput + nview, ninput // 2)  # (256 + view_dim) -> 128
        self.alpha = nn.Linear(ninput, 1)  # volume density
        self.rgb = nn.Linear(ninput // 2, 3)  # 128 -> rgb
    
    def forward(self, x, view_dirs):
        
        feature = self.feature(x)
        sigma   = self.alpha(x).relu()
        feature = torch.cat([feature, view_dirs], dim=-1)
        feature = self.view_fc(feature).relu()
        rgb     = self.rgb(feature).sigmoid()
        
        
        return sigma, rgb

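# view-dependent head for the TensoRF path
class ViewDenepdentHead_tensorRF(nn.Module):
    def __init__(self, ninput, nview):
        super().__init__()
        # the ninput argument is overridden here:
        # 21 (encoded sigma feature) + 567 (encoded 27-d appearance feature) + 27 (encoded view direction) = 615
        ninput = 21 + 567 +27 

        self.view_fc = nn.Linear(ninput , ninput // 2) # 615 -> 307
 
        self.rgb = nn.Linear(ninput // 2, 3) # 307 -> rgb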

    def forward(self, sigma_value,x_old, view_dirs,sigma_feature,app_features):
        
        
        feature = torch.cat([sigma_feature, app_features], dim=-1)                        
        feature = torch.cat([feature, view_dirs], dim=-1)
                
        feature = self.view_fc(feature).relu()
        rgb     = self.rgb(feature).sigmoid()
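        sigma = sigma_value  # passed through as-is: no ReLU/softplus here, unlike the heads above
        
        C,H,W =x_old  # x_old is the original sample shape (n_rays, n_samples, 3)
        return  sigma.reshape(C,H,1),rgb.reshape(x_old)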

# 位置编码实现
class Embedder(nn.Module):
    def __init__(self, positional_encoding_dim):
        super().__init__()
        self.positional_encoding_dim = positional_encoding_dim

    def forward(self, x):
        
        positions = [x]
 
        for i in range(self.positional_encoding_dim):
            for fn in [torch.sin, torch.cos]:
                positions.append(fn((2.0 ** i) * x))

        return torch.cat(positions, dim=-1)



# 基本模型结构
class NeRF(nn.Module):
    def __init__(self, x_pedim=10, nwidth=256, ndepth=8, view_pedim=4):
        super().__init__()
        
        xdim         = (x_pedim * 2 + 1) * 3

        layers       =  []
        layers_in    = [nwidth] * ndepth
        layers_in[0] = xdim
        layers_in[5] = nwidth + xdim
        # VM (vector-matrix) decomposition
        self.density_n_comp = [16,16,16]
        self.app_n_comp =[48,48,48]
        self.gridSize =  np.array([128,128,128])
        self.matMode = [[0,1], [0,2], [1,2]]
        self.vecMode =  [2, 1, 0]
        self.comp_w = [1,1,1]
        self.app_dim = 27
        device = "cuda:0"
        self.sigma_feature = []
        self.app_features = []
        self.density_plane, self.density_line = self.init_one_svd(self.density_n_comp, self.gridSize, 0.1, device)
        self.app_plane, self.app_line = self.init_one_svd(self.app_n_comp, self.gridSize, 0.1, device)
     
        self.basis_mat = torch.nn.Linear(sum(self.app_n_comp), self.app_dim, bias=False).to(device)  # project the appearance features before the network: 48+48+48 -> 27
        self.tensorf = 1
        # a specific layer (index 5) takes a concatenated input
        for i in range(ndepth):
            layers.append(nn.Linear(layers_in[i], nwidth))
        
        if view_pedim > 0:
            view_dim = (view_pedim * 2 + 1) * 3
            self.view_embed = Embedder(view_pedim)
            if self.tensorf == 1:        
                                   
                self.head = ViewDenepdentHead_tensorRF(nwidth, view_dim)
            else:
                self.head = ViewDenepdentHead(nwidth, view_dim)
        else:
            self.head = NoViewDirHead(nwidth, 4)
       
        self.xembed = Embedder(x_pedim)
        if self.tensorf == 0: 
            self.layers = nn.Sequential(*layers)

    def compute_densityfeature(self, xyz_sampled):

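        # plane + line basis
        # note: F.grid_sample reads coordinates in [-1, 1]; xyz_sampled is passed in unnormalized,
        # so points outside that range fall into the default zero padding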
        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
        coordinate_line = (torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1))
        coordinate_line = coordinate_line.detach().view(3, -1, 1, 2)
        sigma_feature = torch.zeros((xyz_sampled.shape[0],), device=xyz_sampled.device)
        for idx_plane in range(len(self.density_plane)):
            plane_coef_point = F.grid_sample(self.density_plane[idx_plane], coordinate_plane[[idx_plane]],
                                                align_corners=True).view(-1, *xyz_sampled.shape[:1])
            line_coef_point = F.grid_sample(self.density_line[idx_plane], coordinate_line[[idx_plane]],
                                            align_corners=True).view(-1, *xyz_sampled.shape[:1])
            sigma_feature = sigma_feature + torch.sum(plane_coef_point * line_coef_point, dim=0)
        return sigma_feature
    # example arguments: n_component=[16, 16, 16], gridSize=[128, 128, 128], scale=0.1, device="cuda"
    def init_one_svd(self, n_component=[16,16,16], gridSize=[128,128,128], scale=0.1, device="cuda:0"):
        plane_coef, line_coef = [], []
        matMode = [[0,1], [0,2], [1,2]]
        vecMode =  [2, 1, 0]
        comp_w = [1,1,1]
        for i in range(len(vecMode)):
            vec_id = vecMode[i]
            mat_id_0, mat_id_1 = matMode[i]
            plane_coef.append(torch.nn.Parameter(
                scale * torch.randn((1, n_component[i], gridSize[mat_id_1], gridSize[mat_id_0]))))  #
            print(f"plane matrix  components dim[{i}] 1",n_component[i], gridSize[mat_id_1].item(), gridSize[mat_id_0].item())
            line_coef.append(
                torch.nn.Parameter(scale * torch.randn((1, n_component[i], gridSize[vec_id], 1))))
            print(f"line vector  components dim[{i}] 1",n_component[i], gridSize[vec_id].item())

        return torch.nn.ParameterList(plane_coef).to(device), torch.nn.ParameterList(line_coef).to(device)

    def compute_appfeature(self, xyz_sampled):
        
        # plane + line basis
        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)

        plane_coef_point,line_coef_point = [],[]
        for idx_plane in range(len(self.app_plane)):
            plane_coef_point.append(F.grid_sample(self.app_plane[idx_plane], coordinate_plane[[idx_plane]],
                                                align_corners=True).view(-1, *xyz_sampled.shape[:1]))
            line_coef_point.append(F.grid_sample(self.app_line[idx_plane], coordinate_line[[idx_plane]],
                                            align_corners=True).view(-1, *xyz_sampled.shape[:1]))
        plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point)



        return self.basis_mat((plane_coef_point * line_coef_point).T)

    
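    def forward(self, x, view_dirs):  # the view directions are passed in here
        
        xshape = x.shape
        x_old = xshape
        if self.tensorf == 0:
            x = self.xembed(x)  # Fourier encoding of the positions; this branch is not taken while self.tensorf == 1
            
            if self.view_embed is not None:
                view_dirs = view_dirs[:, None].expand(xshape)  # broadcast each ray's view direction to all of its sample points (64 per ray)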
                view_dirs = self.view_embed(view_dirs)      
            
            raw_x = x
            for i, layer in enumerate(self.layers):
                x = torch.relu(layer(x))
                
                if i == 4:
                    x = torch.cat([x, raw_x], axis=-1)

            return self.head(x, view_dirs)
        else:
            C, H, W = x.shape   
            
            new_shape = (C * H, W)  
            x = x.view(new_shape) 
            
            self.sigma_feature = self.compute_densityfeature(x)
            self.sigma_feature = self.sigma_feature.unsqueeze(1)
            
            self.app_features = self.compute_appfeature(x)

            #print("sigma__1",self.sigma_feature.shape)
            sigma_feature_em = self.xembed(self.sigma_feature)
            #print("sigma__2",sigma_feature_em.shape)

            #print("app__1",self.app_features.shape)
            app_features_em = self.xembed(self.app_features)
            #print("app__2",app_features_em.shape)

            if self.view_embed is not None:
                #print("zhouluo_1",view_dirs.shape)
                view_dirs = view_dirs[:, None].expand(xshape)  
                #print("zhouluo_2",view_dirs.shape)
                view_dirs = self.view_embed(view_dirs)   
                #print("zhouluo_3",view_dirs.shape)                             
                C, H, W = view_dirs.shape
                new_shape = (C * H, W)  
                view_dirs = view_dirs.reshape(new_shape)
                #print("zhouluo_4",view_dirs.shape)  
                
            
            
            
            return self.head(self.sigma_feature,x_old,view_dirs,sigma_feature_em,app_features_em)

            



def train():

    pbar     = tqdm(range(1, maxiters))
    for global_step in pbar:

        idx   = np.random.randint(0, len(trainset))
        raydirs, rayoris, imagepixels = trainset[idx]

        rgb1, rgb2 = render_rays(model, fine, raydirs, rayoris, sample_z_vals, importance, white_background)
        loss1 = ((rgb1 - imagepixels)**2).mean()
        loss2 = ((rgb2 - imagepixels)**2).mean()
        psnr  = -10. * torch.log(loss2.detach()) / np.log(10.)
        loss  = loss1 + loss2
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pbar.set_description(f"{global_step} / {maxiters}, Loss: {loss.item():.6f}, PSNR: {psnr.item():.6f}")

        decay_rate = 0.1
        new_lrate  = lrate * (decay_rate ** (global_step / lrate_decay))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate

        if global_step % 500 == 0:

            imgpath = f"imgs/{global_step:02d}.png"
            pthpath = f"ckpt/{global_step:02d}.pth"
            model.eval()
            with torch.no_grad():
                rgbs, imgpixels = [], []
                for raydirs, rayoris, imagepixels in trainset.get_test_item():

                    rgb1, rgb2  = render_rays(model, fine, raydirs, rayoris, sample_z_vals, importance, white_background)
                    rgbs.append(rgb2)
                    imgpixels.append(imagepixels)

                rgb       = torch.cat(rgbs, dim=0)
                imgpixels = torch.cat(imgpixels, dim=0)
                loss      = ((rgb - imgpixels)**2).mean()
                psnr      = -10. * torch.log(loss) / np.log(10.)

                print(f"Save image {imgpath}, Loss: {loss.item():.6f}, PSNR: {psnr.item():.6f}")
            model.train()
            
            temp_image = (rgb.view(height, width, 3).cpu().numpy() * 255).astype(np.uint8)
            cv2.imwrite(imgpath, temp_image[..., ::-1])
            torch.save([model.state_dict(), fine.state_dict()], pthpath)


def make_video360():

    mstate, fstate = torch.load(args.ckpt, map_location="cpu")
    model.load_state_dict(mstate)
    fine.load_state_dict(fstate)
    model.eval()
    fine.eval()
    imagelist = []

    for i, gfn in tqdm(enumerate(trainset.get_rotate_360_rays()), desc="Rendering"):

        with torch.no_grad():
            rgbs = []
            for raydirs, rayoris in gfn():
                rgb1, rgb2 = render_rays(model, fine, raydirs, rayoris, sample_z_vals, importance, white_background)
                rgbs.append(rgb2)

            rgb = torch.cat(rgbs, dim=0)
        
        rgb  = (rgb.view(height, width, 3).cpu().numpy() * 255).astype(np.uint8)
        file = f"rotate360/{i:03d}.png"

        print(f"Rendering to {file}")
        cv2.imwrite(file, rgb[..., ::-1])
        imagelist.append(rgb)

    video_file = f"videos/rotate360.mp4"
    print(f"Write imagelist to video file {video_file}")
    imageio.mimwrite(video_file, imagelist, fps=30, quality=10)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--datadir", type=str, default='data/nerf_synthetic/lego', help='input data directory')
    parser.add_argument("--make-video360", action="store_true", help="make 360 rotation video")
    parser.add_argument("--half-resolution", action="store_true", help="use half resolution")
    parser.add_argument("--ckpt", default="ckpt/2000.pth", type=str, help="model file used to make 360 rotation video")
    args = parser.parse_args()

    device      = "cuda:0"
    maxiters    = 5000 + 1
    batch_size  = 1024*4  # this may take about 15 GB of GPU memory
    lrate_decay = 500 * 1000
    lrate       = 5e-4
    importance  = 128
    num_samples = 64                    # number of samples per ray
    positional_encoding_dim = 10        # positional encoding dimension
    view_encoding_dim       = 4         # positional encoding dimension for the view direction
    white_background        = True      # the images have a white background
    half_resolution         = args.half_resolution    # reconstruct at half resolution (400x400); False means 800x800
    sample_z_vals           = torch.linspace(2.0, 6.0, num_samples, device=device).view(1, num_samples)

    model = NeRF(
        x_pedim    = positional_encoding_dim,
        view_pedim = view_encoding_dim
    ).to(device)
    params = list(model.parameters())
    
    # fine runs on the points resampled with the weights produced by model, so it is the better-performing model
    fine = NeRF(
        x_pedim    = positional_encoding_dim,
        view_pedim = view_encoding_dim
    ).to(device)
    params.extend(list(fine.parameters()))

    optimizer = optim.Adam(params, lrate)
    os.makedirs("imgs",      exist_ok=True)
    os.makedirs("rotate360", exist_ok=True)
    os.makedirs("videos",    exist_ok=True)
    os.makedirs("ckpt",      exist_ok=True)

    print(model)

    provider = BlenderProvider(args.datadir, "transforms_train.json", half_resolution)
    trainset = NeRFDataset(provider, batch_size, device)
    width    = trainset.width
    height   = trainset.height

    if args.make_video360:
        make_video360()
    else:
        train()

    print("Program done.")
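
To make the positional-encoding note at the top of the listing concrete, here is a minimal standalone sketch in plain Python. It follows the same ordering as the `Embedder` class (raw input first, then sin and cos at each frequency), but the function name `positional_encoding` and the sample point are illustrative only.

```python
import math

def positional_encoding(coords, num_freqs):
    """Return coords followed by sin and cos of (2**i * coords) for i in 0..num_freqs-1."""
    out = list(coords)
    for i in range(num_freqs):
        for fn in (math.sin, math.cos):
            out.extend(fn((2.0 ** i) * c) for c in coords)
    return out

point = [0.1, -0.4, 0.7]                 # an arbitrary x, y, z sample
encoded = positional_encoding(point, 10)
print(len(encoded))                      # (2 * 10 + 1) * 3 = 63
```

This is where `xdim = (x_pedim * 2 + 1) * 3` in the `NeRF` class comes from: each of the 3 coordinates contributes itself plus a sin and a cos term per frequency.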

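The hierarchical sampling step can also be looked at in isolation. The sketch below is a simplified NumPy version of the inverse-CDF idea behind `sample_pdf`, for a single ray and with deterministic `u` values; `resample_along_ray` and the toy weights are assumptions made for illustration, not code from the listing.

```python
import numpy as np

def resample_along_ray(bin_edges, weights, n_samples):
    """Inverse-CDF sampling: place depths in proportion to the per-bin weights."""
    weights = np.asarray(weights, dtype=np.float64) + 1e-5   # avoid an all-zero pdf
    pdf = weights / weights.sum()
    cdf = np.concatenate([[0.0], np.cumsum(pdf)])            # same length as bin_edges
    u = np.linspace(0.0, 1.0, n_samples)                     # evenly spaced, like det=True
    # piecewise-linear inverse of the CDF
    return np.interp(u, cdf, bin_edges)

edges = np.linspace(2.0, 6.0, 9)     # 8 bins between near=2 and far=6
w = np.zeros(8)
w[3] = 1.0                           # all the weight sits in the bin [3.5, 4.0]
z = resample_along_ray(edges, w, 16)
inside = np.mean((z >= 3.5) & (z <= 4.0))
```

With the weight concentrated in one bin, nearly all of the new depths land inside that bin, which is the behaviour the fine model relies on.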