个人代码学习笔记

一、demo.py

import os
# Expose only GPUs 6 and 7 to this process. Set before `import torch` below so
# the CUDA runtime sees only these devices (hard-coded; adjust per machine).
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"

import sys
import time
from argparse import ArgumentParser
from tqdm import tqdm
import wandb
import yaml
import torch

from core.utils import calibration, dataset
from core.utils.logging_utils import init_cache, Log
from core.slam import SLAM

def main():
    # Set up command line argument parser设置命令行参数解析器
    parser = ArgumentParser(description="MCSLAM Demo")
    parser.add_argument("--config", type=str, required=True, help="Path to the configuration file")
    parser.add_argument("--eval", action="store_false", default=False, help="Evaluation mode")
    
    args = parser.parse_args()
    torch.multiprocessing.set_start_method('spawn')

    with open(args.config, "r") as yml:
        config = yaml.safe_load(yml)
    
    # load calibration负载校准
    calib = calibration.Calibration(config["Calibration"])
    
    # load dataset加载数据集
    db = dataset.load_dataset(config["Dataset"], calib)
    
    # initialize the SLAM system
    slam = SLAM(config, calib)
    
    # run the demo跑模型
    for frame in tqdm(db):
        slam.feed_frame(frame)
        
    slam.finalize()
    
    # All done完成
    Log("Done.")
    
if __name__ == "__main__":
    main()

二、backend.py

import random
import time

import torch
import torch.multiprocessing as mp

from .viser import gaussians_viz
from .mapper import Mapper
    
class Backend:
    """Run the Mapper (and optionally the viser viewer) in child processes."""

    def __init__(self, config, window, calib):
        self.config = config
        self.device = "cuda"
        self.window = window
        self.calib = calib
        # The Mapper owns the Gaussian map; its `run` loop executes in its own process.
        self.mapper = Mapper(config)
        self.backend_process = mp.Process(target=self.mapper.run)
        self.backend_process.start()

        # NOTE: "viz_enalbe" (sic) is the config key spelling this code reads.
        if self.config["viz_enalbe"]:
            self.viz_process = mp.Process(
                target=gaussians_viz,
                args=(self.mapper, self.device, self.config['port_num']),
            )
            self.viz_process.start()

    def finalize(self):
        """Stop the mapping process and the viewer process (if any), then wait for them."""
        # The Mapper's `run` loop breaks on a "stop" message.
        self.mapper.backend_queue.put(["stop"])
        self.backend_process.join()
        if self.config["viz_enalbe"]:
            # `gaussians_viz` loops forever and never returns on its own, so a
            # bare join() would block indefinitely; terminate it first.
            self.viz_process.terminate()
            self.viz_process.join()

三、mapper.py

import random
import time

import torch
import torch.multiprocessing as mp
from tqdm import tqdm
from munch import munchify

from .gaussian_splatting.gaussian_model import GaussianModel
from .gaussian_splatting.gaussian_renderer import render
from .gaussian_splatting.utils.loss_utils import l1_loss, ssim

#计算输入图像和深度图像在特定视点下的 RGB 损失
def get_loss_mapping_rgb(config, image, depth, viewpoint):
    """Masked mean L1 photometric loss against the viewpoint's ground-truth image.

    Pixels whose ground-truth channel sum is not above
    `config["Training"]["rgb_boundary_threshold"]` are zeroed in both images,
    so they contribute 0 to the loss but still count in the mean.

    Args:
        config: configuration dict with a "Training" section.
        image: rendered image, elementwise-comparable with the ground truth.
        depth: unused here.
        viewpoint: object whose `original_image` is the (C, H, W) ground truth.

    Returns:
        Scalar tensor: mean absolute difference over all elements.
    """
    target = viewpoint.original_image.cuda()
    _, height, width = target.shape
    threshold = config["Training"]["rgb_boundary_threshold"]

    # Channel-summed brightness test, reshaped to (1, H, W) to broadcast over C.
    valid = (target.sum(dim=0) > threshold).view(1, height, width)
    residual = torch.abs(image * valid - target * valid)
    return residual.mean()


#计算输入的 RGB 图像和深度图像与 ground truth 图像和深度图像之间的混合损失(RGB 和深度)
def get_loss_mapping_rgbd(config, image, depth, viewpoint, initialization=False):
    """Weighted sum of masked mean L1 colour and depth losses.

    Returns `alpha * rgb_l1 + (1 - alpha) * depth_l1`, with `alpha` read
    from `config["Training"]["alpha"]` (default 0.95). The colour mask keeps
    pixels whose ground-truth channel sum exceeds `rgb_boundary_threshold`;
    the depth mask keeps pixels with ground-truth depth > 0.01. Masked-out
    pixels contribute 0 but still count in each mean.

    Args:
        config: configuration dict with a "Training" section.
        image: rendered colour image.
        depth: rendered depth; both masks are reshaped to `depth.shape`.
        viewpoint: object with `original_image` (tensor) and `depth` (numpy array).
        initialization: unused here.

    Returns:
        Scalar tensor.
    """
    training_cfg = config["Training"]
    alpha = training_cfg.get("alpha", 0.95)
    threshold = training_cfg["rgb_boundary_threshold"]

    target_rgb = viewpoint.original_image.cuda()
    # Add a leading axis so the ground-truth depth lines up with `depth`.
    target_depth = torch.from_numpy(viewpoint.depth).to(dtype=torch.float32, device=image.device)[None]

    mask_shape = depth.shape
    rgb_valid = (target_rgb.sum(dim=0) > threshold).view(*mask_shape)
    depth_valid = (target_depth > 0.01).view(*mask_shape)

    rgb_term = torch.abs(image * rgb_valid - target_rgb * rgb_valid).mean()
    depth_term = torch.abs(depth * depth_valid - target_depth * depth_valid).mean()
    return alpha * rgb_term + (1 - alpha) * depth_term

#计算基于配置的图像和深度图像的损失,可以选择 RGB 损失或 RGB-D 损失。
def get_loss_mapping(config, image, depth, viewpoint, opacity, initialization=False):
    """Dispatch to the RGB-only or RGB-D mapping loss.

    Outside initialization the rendered image first gets the viewpoint's
    affine exposure correction `exp(exposure_a) * image + exposure_b`.
    `config["Training"]["monocular"]` selects the RGB-only loss; otherwise
    the RGB-D loss is used. `opacity` is unused.
    """
    if not initialization:
        image = (torch.exp(viewpoint.exposure_a)) * image + viewpoint.exposure_b
    loss_fn = get_loss_mapping_rgb if config["Training"]["monocular"] else get_loss_mapping_rgbd
    return loss_fn(config, image, depth, viewpoint)


class Mapper(mp.Process):
    """Gaussian-splatting mapping backend.

    Holds a `GaussianModel` and the keyframe viewpoints it is fitted to
    (`self.viewpoints`, keyed by keyframe id). `run` is the worker loop: it
    reacts to control messages on `self.backend_queue` ("stop", "pause",
    "unpause", "color_refinement") and otherwise keeps optimizing the map.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.device = "cuda"
        self.dtype = torch.float32

        # Control-message queue consumed by `run`.
        self.backend_queue = mp.Queue()
        self.pause = False
        self.iteration_count = 0
        self.initialized = False

        self.model_params = munchify(config["model_params"])
        self.opt_params = munchify(config["opt_params"])
        self.pipeline_params = munchify(config["pipeline_params"])

        # Spherical-harmonics degree: 3 when SH colours are enabled, else 0.
        self.use_spherical_harmonics = self.config["Training"]["spherical_harmonics"]
        self.model_params.sh_degree = 3 if self.use_spherical_harmonics else 0
        # Read by `map` when deciding whether to prune. It was never assigned
        # before, so the pruning branch raised AttributeError.
        self.monocular = self.config["Training"].get("monocular", False)

        self.gaussians = GaussianModel(self.model_params.sh_degree, config=self.config)
        self.gaussians.init_lr(6.0)
        self.gaussians.training_setup(self.opt_params)
        # Black background colour passed to `render`.
        self.background = torch.tensor([0, 0, 0], dtype=self.dtype, device=self.device)

        # Must be set before `set_hyperparams` is called (it scales by it).
        self.cameras_extent = None
        self.frontend_queue = None

        self.last_sent = 0
        # kf_idx -> per-Gaussian visibility mask, rebuilt by `map`.
        self.occ_aware_visibility = {}
        # kf_idx -> viewpoint object.
        self.viewpoints = {}
        self.current_window = []

        # Optional pose optimizer; `map` only steps it once it is assigned.
        self.keyframe_optimizers = None

    def set_hyperparams(self):
        """Copy mapping hyper-parameters from `self.config` onto the instance.

        Requires `self.cameras_extent` to be a number: the Gaussian extents
        are scaled by it.
        """
        self.save_results = self.config["Results"]["save_results"]

        self.init_itr_num = self.config["Training"]["init_itr_num"]
        self.init_gaussian_update = self.config["Training"]["init_gaussian_update"]
        self.init_gaussian_reset = self.config["Training"]["init_gaussian_reset"]
        self.init_gaussian_th = self.config["Training"]["init_gaussian_th"]
        self.init_gaussian_extent = (
            self.cameras_extent * self.config["Training"]["init_gaussian_extent"]
        )
        self.mapping_itr_num = self.config["Training"]["mapping_itr_num"]
        self.gaussian_update_every = self.config["Training"]["gaussian_update_every"]
        self.gaussian_update_offset = self.config["Training"]["gaussian_update_offset"]
        self.gaussian_th = self.config["Training"]["gaussian_th"]
        self.gaussian_extent = (
            self.cameras_extent * self.config["Training"]["gaussian_extent"]
        )
        self.gaussian_reset = self.config["Training"]["gaussian_reset"]
        self.size_threshold = self.config["Training"]["size_threshold"]
        self.window_size = self.config["Training"]["window_size"]

    def initialize_map(self, viewpoint):
        """Fit the Gaussians to one viewpoint for `init_itr_num` iterations.

        Uses the initialization-mode loss (no exposure correction). It
        densifies and prunes every `init_gaussian_update` iterations and
        resets opacity when the global iteration count hits
        `init_gaussian_reset` or `densify_from_iter`.

        Returns:
            The render package from the last iteration (assumes
            `init_itr_num >= 1`).
        """
        for mapping_iteration in range(self.init_itr_num):
            self.iteration_count += 1
            render_pkg = render(
                viewpoint, self.gaussians, self.pipeline_params, self.background
            )
            (
                image,
                viewspace_point_tensor,
                visibility_filter,
                radii,
                depth,
                opacity,
                n_touched,
            ) = (
                render_pkg["render"],
                render_pkg["viewspace_points"],
                render_pkg["visibility_filter"],
                render_pkg["radii"],
                render_pkg["depth"],
                render_pkg["opacity"],
                render_pkg["n_touched"],
            )
            loss_init = get_loss_mapping(
                self.config, image, depth, viewpoint, opacity, initialization=True
            )
            loss_init.backward()

            # Model updates happen without gradient tracking.
            with torch.no_grad():
                self.gaussians.max_radii2D[visibility_filter] = torch.max(
                    self.gaussians.max_radii2D[visibility_filter],
                    radii[visibility_filter],
                )
                self.gaussians.add_densification_stats(
                    viewspace_point_tensor, visibility_filter
                )
                # Periodic densification and pruning.
                if mapping_iteration % self.init_gaussian_update == 0:
                    self.gaussians.densify_and_prune(
                        self.opt_params.densify_grad_threshold,
                        self.init_gaussian_th,
                        self.init_gaussian_extent,
                        None,
                    )

                if self.iteration_count == self.init_gaussian_reset or (
                    self.iteration_count == self.opt_params.densify_from_iter
                ):
                    self.gaussians.reset_opacity()

                self.gaussians.optimizer.step()
                self.gaussians.optimizer.zero_grad(set_to_none=True)

        # NOTE(review): `Log` is not among mapper.py's visible imports.
        Log("Initialized map")
        return render_pkg

    def map(self, current_window, prune=False, iters=1):
        """Jointly optimize the Gaussians over the keyframes in `current_window`.

        Each iteration does the following:
        1. Renders every window keyframe, plus up to two random keyframes from
           outside the window.
        2. Sums their mapping losses with an isotropy regularizer and
           back-propagates.
        3. Without gradients, rebuilds per-keyframe visibility and then either
           runs the pruning pass or densifies / resets opacity.
        4. Steps the optimizers and updates poses.

        Args:
            current_window: keyframe ids (keys of `self.viewpoints`).
            prune: if True, run the visibility-based pruning pass (effective
                only when the window is full). The call then returns False
                without stepping the optimizers.
            iters: number of optimization iterations.

        Returns:
            None for an empty window. Otherwise, whether the last iteration
            densified Gaussians or reset their opacity.
        """
        if len(current_window) == 0:
            return

        viewpoint_stack = [self.viewpoints[kf_idx] for kf_idx in current_window]
        # Only the first `pose_window` window keyframes get pose updates.
        frames_to_optimize = self.config["Training"]["pose_window"]

        # Keyframes outside the window: a random subset of them is rendered
        # each iteration as extra supervision.
        current_window_set = set(current_window)
        random_viewpoint_stack = [
            viewpoint
            for cam_idx, viewpoint in self.viewpoints.items()
            if cam_idx not in current_window_set
        ]

        # Defined up front so `iters == 0` returns False instead of raising
        # UnboundLocalError.
        gaussian_split = False
        for _ in range(iters):
            self.iteration_count += 1
            self.last_sent += 1

            loss_mapping = 0
            viewspace_point_tensor_acm = []
            visibility_filter_acm = []
            radii_acm = []
            n_touched_acm = []

            # Render every keyframe in the window and accumulate its loss.
            for viewpoint in viewpoint_stack:
                render_pkg = render(
                    viewpoint, self.gaussians, self.pipeline_params, self.background
                )
                loss_mapping += get_loss_mapping(
                    self.config,
                    render_pkg["render"],
                    render_pkg["depth"],
                    viewpoint,
                    render_pkg["opacity"],
                )
                viewspace_point_tensor_acm.append(render_pkg["viewspace_points"])
                visibility_filter_acm.append(render_pkg["visibility_filter"])
                radii_acm.append(render_pkg["radii"])
                n_touched_acm.append(render_pkg["n_touched"])

            # Up to two random out-of-window keyframes. Their n_touched is not
            # recorded, so they do not affect the visibility bookkeeping.
            for cam_idx in torch.randperm(len(random_viewpoint_stack))[:2]:
                viewpoint = random_viewpoint_stack[cam_idx]
                render_pkg = render(
                    viewpoint, self.gaussians, self.pipeline_params, self.background
                )
                loss_mapping += get_loss_mapping(
                    self.config,
                    render_pkg["render"],
                    render_pkg["depth"],
                    viewpoint,
                    render_pkg["opacity"],
                )
                viewspace_point_tensor_acm.append(render_pkg["viewspace_points"])
                visibility_filter_acm.append(render_pkg["visibility_filter"])
                radii_acm.append(render_pkg["radii"])

            # Isotropy regularizer: penalize each scale component's deviation
            # from the mean over its row.
            scaling = self.gaussians.get_scaling
            isotropic_loss = torch.abs(scaling - scaling.mean(dim=1).view(-1, 1))
            loss_mapping += 10 * isotropic_loss.mean()
            loss_mapping.backward()
            gaussian_split = False

            # Densifying / pruning Gaussians.
            with torch.no_grad():
                # Visibility per window keyframe: Gaussians touched at least once.
                self.occ_aware_visibility = {}
                for kf_idx, n_touched in zip(current_window, n_touched_acm):
                    self.occ_aware_visibility[kf_idx] = (n_touched > 0).long()

                if prune:
                    # Pruning only runs with a full window.
                    if len(current_window) == self.config["Training"]["window_size"]:
                        prune_mode = self.config["Training"]["prune_mode"]
                        prune_coviz = 3
                        self.gaussians.n_obs.fill_(0)
                        for visibility in self.occ_aware_visibility.values():
                            self.gaussians.n_obs += visibility.cpu()
                        to_prune = None
                        if prune_mode == "odometry":
                            to_prune = self.gaussians.n_obs < 3
                        if prune_mode == "slam":
                            # Only consider Gaussians whose unique_kfIDs is at
                            # least the third-newest keyframe id in the window
                            # (all Gaussians before initialization).
                            sorted_window = sorted(current_window, reverse=True)
                            mask = self.gaussians.unique_kfIDs >= sorted_window[2]
                            if not self.initialized:
                                mask = self.gaussians.unique_kfIDs >= 0
                            to_prune = torch.logical_and(
                                self.gaussians.n_obs <= prune_coviz, mask
                            )
                        if to_prune is not None and self.monocular:
                            self.gaussians.prune_points(to_prune.cuda())
                            for current_idx in current_window:
                                self.occ_aware_visibility[current_idx] = (
                                    self.occ_aware_visibility[current_idx][~to_prune]
                                )
                        if not self.initialized:
                            self.initialized = True
                            Log("Initialized SLAM")
                    # The pruning pass ends the call: no densification or
                    # optimizer step in this iteration.
                    return False

                for idx in range(len(viewspace_point_tensor_acm)):
                    self.gaussians.max_radii2D[visibility_filter_acm[idx]] = torch.max(
                        self.gaussians.max_radii2D[visibility_filter_acm[idx]],
                        radii_acm[idx][visibility_filter_acm[idx]],
                    )
                    self.gaussians.add_densification_stats(
                        viewspace_point_tensor_acm[idx], visibility_filter_acm[idx]
                    )

                update_gaussian = (
                    self.iteration_count % self.gaussian_update_every
                    == self.gaussian_update_offset
                )
                if update_gaussian:
                    self.gaussians.densify_and_prune(
                        self.opt_params.densify_grad_threshold,
                        self.gaussian_th,
                        self.gaussian_extent,
                        self.size_threshold,
                    )
                    gaussian_split = True

                # Opacity reset (skipped on iterations that already densified).
                if (self.iteration_count % self.gaussian_reset) == 0 and (
                    not update_gaussian
                ):
                    Log("Resetting the opacity of non-visible Gaussians")
                    self.gaussians.reset_opacity_nonvisible(visibility_filter_acm)
                    gaussian_split = True

                self.gaussians.optimizer.step()
                self.gaussians.optimizer.zero_grad(set_to_none=True)
                self.gaussians.update_learning_rate(self.iteration_count)
                # The pose optimizer is None until assigned elsewhere; calling
                # .step() on None raised AttributeError.
                if self.keyframe_optimizers is not None:
                    self.keyframe_optimizers.step()
                    self.keyframe_optimizers.zero_grad(set_to_none=True)
                # Pose update for the first `pose_window` keyframes; the
                # viewpoint with uid 0 is skipped.
                # NOTE(review): `update_pose` is not among mapper.py's visible imports.
                for cam_idx in range(min(frames_to_optimize, len(current_window))):
                    viewpoint = viewpoint_stack[cam_idx]
                    if viewpoint.uid == 0:
                        continue
                    update_pose(viewpoint)
        return gaussian_split

    def map2(self, current_window, prune=False, iters=1):
        """Simplified mapping step: render each keyframe, sum losses, take one optimizer step.

        Args:
            current_window: either a dict whose values are viewpoints (as
                passed by `run` with `self.viewpoints`) or a sequence of
                viewpoints.
            prune: unused.
            iters: number of optimization iterations.

        Returns:
            None.
        """
        # Accept the dict `run` passes as well as a plain sequence; indexing a
        # dict with 0..n-1 only worked by accident of its keys.
        if isinstance(current_window, dict):
            keyframes = list(current_window.values())
        else:
            keyframes = list(current_window)
        if not keyframes:
            return

        for _ in range(iters):
            self.iteration_count += 1
            self.last_sent += 1

            loss_mapping = 0
            for keyframe in keyframes:
                render_pkg = render(keyframe, self.gaussians, self.pipeline_params, self.background)
                # Previously passed the undefined name `viewpoint` here (NameError).
                loss_mapping += get_loss_mapping(
                    self.config,
                    render_pkg["render"],
                    render_pkg["depth"],
                    keyframe,
                    render_pkg["opacity"],
                )

            # Regularize and back-propagate once per iteration. Calling
            # backward() inside the per-keyframe loop re-traversed an already
            # freed graph.
            scaling = self.gaussians.get_scaling
            isotropic_loss = torch.abs(scaling - scaling.mean(dim=1).view(-1, 1))
            loss_mapping += 10 * isotropic_loss.mean()
            loss_mapping.backward()

            # Apply the gradients (as `map` does) instead of letting them accumulate.
            with torch.no_grad():
                self.gaussians.optimizer.step()
                self.gaussians.optimizer.zero_grad(set_to_none=True)

    def initialize(self):
        pass

    def run(self):
        """Worker loop: handle control messages, otherwise keep mapping.

        Messages are lists whose first element is one of "stop", "pause",
        "unpause" or "color_refinement". Any other message raises. On "stop"
        the queue is drained and the method returns.

        Raises:
            NotImplementedError: on a "color_refinement" message.
            Exception: on an unrecognized message.
        """
        while True:
            if self.backend_queue.empty():
                if self.pause:
                    time.sleep(0.01)
                    continue

                if not self.initialized and len(self.viewpoints):
                    self.map2(self.viewpoints)
                    self.initialized = True
                    continue

                if self.initialized:
                    self.map2(self.viewpoints)
                else:
                    # Nothing to map yet: back off instead of spinning at 100% CPU.
                    time.sleep(0.01)
            else:
                data = self.backend_queue.get()
                if data[0] == "stop":
                    break
                elif data[0] == "pause":
                    self.pause = True
                elif data[0] == "unpause":
                    self.pause = False
                elif data[0] == "color_refinement":
                    # Previously the exception was constructed but never raised.
                    raise NotImplementedError("color_refinement is not implemented")
                else:
                    raise Exception("Unprocessed data", data)
        # Discard any messages left after "stop".
        while not self.backend_queue.empty():
            self.backend_queue.get()

        return

四、viser.py

from threading import Thread
import torch
import torch.multiprocessing as mp
import numpy as np
import time
import viser
import viser.transforms as tf
from omegaconf import OmegaConf
import cv2
from collections import deque

from .gaussian_splatting.gaussian_renderer import render
from .gaussian_splatting.utils.camera_utils import Camera

#交互式可视化模块,用于处理Gaussian Splatting的渲染工作
def qvec2rotmat(qvec):
    """Convert a quaternion (w, x, y, z) to a 3x3 rotation matrix (numpy array).

    Only the first four entries of `qvec` are read; the quaternion is not
    normalized here.
    """
    w, x, y, z = qvec[0], qvec[1], qvec[2], qvec[3]
    row0 = [
        1 - 2 * y ** 2 - 2 * z ** 2,
        2 * x * y - 2 * w * z,
        2 * z * x + 2 * w * y,
    ]
    row1 = [
        2 * x * y + 2 * w * z,
        1 - 2 * x ** 2 - 2 * z ** 2,
        2 * y * z - 2 * w * x,
    ]
    row2 = [
        2 * z * x - 2 * w * y,
        2 * y * z + 2 * w * x,
        1 - 2 * x ** 2 - 2 * y ** 2,
    ]
    return np.array([row0, row1, row2])

def get_c2w(camera):
    """Build the 4x4 float32 camera-to-world matrix from a camera's `wxyz` and `position`."""
    pose = np.eye(4, dtype=np.float32)
    rotation = qvec2rotmat(camera.wxyz)
    pose[:3, :3] = rotation
    pose[:3, 3] = camera.position
    return pose

def get_w2c(camera):
    """Return the world-to-camera matrix: the inverse of `get_c2w(camera)`."""
    return np.linalg.inv(get_c2w(camera))

def vfov_to_hfov(vfov_deg, height, width):
    """Convert a vertical field of view (degrees) to the horizontal one for a width x height image.

    Uses tan(hfov / 2) = (width / height) * tan(vfov / 2).
    """
    # http://paulbourke.net/miscellaneous/lens/
    half_vfov_rad = np.deg2rad(vfov_deg) / 2
    half_hfov_rad = np.arctan(width * np.tan(half_vfov_rad) / height)
    return np.rad2deg(2 * half_hfov_rad)


def gaussians_viz(mapper, device, viewer_port):
    """Serve an interactive viser viewer that renders `mapper.gaussians`.

    GUI handles and shared state are stored as attributes on the function
    object so the nested callbacks can reach them. The function loops forever
    and never returns; it is meant to be the target of its own process.

    Args:
        mapper: object exposing `gaussians`, `pipeline_params` and
            `background`, which are passed to `render`.
        device: stored as `gaussians_viz.device`; not otherwise used here.
        viewer_port: port the viser server listens on.
    """
    gaussians_viz.mapper = mapper
    gaussians_viz.device = device

    # Last three render durations (seconds), averaged for the FPS readout.
    gaussians_viz.render_times = deque(maxlen=3)
    gaussians_viz.server = viser.ViserServer(port=viewer_port)
    gaussians_viz.reset_view_button = gaussians_viz.server.add_gui_button("Reset View")

    # Set True by the GUI/camera callbacks below; rendering starts once it is True.
    gaussians_viz.need_update = False

    # NOTE(review): pause_training only drives GUI state in this function; it
    # is not forwarded to the mapper.
    gaussians_viz.pause_training = False
    gaussians_viz.train_viewer_update_period_slider = gaussians_viz.server.add_gui_slider(
        "Train Viewer Update Period",
        min=1,
        max=100,
        step=1,
        initial_value=10,
        disabled=gaussians_viz.pause_training,
    )
    gaussians_viz.pause_training_button = gaussians_viz.server.add_gui_button("Pause Training")
    gaussians_viz.sh_order = gaussians_viz.server.add_gui_slider(
        "SH Order", min=1, max=4, step=1, initial_value=1
    )
    gaussians_viz.resolution_slider = gaussians_viz.server.add_gui_slider(
        "Resolution", min=384, max=4096, step=2, initial_value=1024
    )
    gaussians_viz.near_plane_slider = gaussians_viz.server.add_gui_slider(
        "Near", min=0.1, max=30, step=0.5, initial_value=0.1
    )
    gaussians_viz.far_plane_slider = gaussians_viz.server.add_gui_slider(
        "Far", min=30.0, max=1000.0, step=10.0, initial_value=1000.0
    )

    gaussians_viz.show_train_camera = gaussians_viz.server.add_gui_checkbox(
        "Show Train Camera", initial_value=False
    )

    gaussians_viz.fps = gaussians_viz.server.add_gui_text("FPS", initial_value="-1", disabled=True)

    # Any change to these controls requests a re-render. The resolution slider
    # used to be registered twice; once is enough.
    @gaussians_viz.show_train_camera.on_update
    def _(_):
        gaussians_viz.need_update = True

    @gaussians_viz.resolution_slider.on_update
    def _(_):
        gaussians_viz.need_update = True

    @gaussians_viz.near_plane_slider.on_update
    def _(_):
        gaussians_viz.need_update = True

    @gaussians_viz.far_plane_slider.on_update
    def _(_):
        gaussians_viz.need_update = True

    @gaussians_viz.pause_training_button.on_click
    def _(_):
        gaussians_viz.pause_training = not gaussians_viz.pause_training
        # Same rule as the initial setup (disabled == pause_training); the old
        # `not pause_training` inverted it after the first toggle.
        gaussians_viz.train_viewer_update_period_slider.disabled = gaussians_viz.pause_training
        gaussians_viz.pause_training_button.name = (
            "Resume Training" if gaussians_viz.pause_training else "Pause Training"
        )

    @gaussians_viz.reset_view_button.on_click
    def _(_):
        gaussians_viz.need_update = True
        for client in gaussians_viz.server.get_clients().values():
            client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array(
                [0.0, -1.0, 0.0]
            )

    gaussians_viz.c2ws = []
    gaussians_viz.camera_infos = []

    @gaussians_viz.server.on_client_connect
    def _(client: viser.ClientHandle):
        @client.camera.on_update
        def _(_):
            gaussians_viz.need_update = True

    gaussians_viz.debug_idx = 0

    def get_current_cam(camera):
        """Build a render `Camera` from a viser client camera at the GUI resolution."""
        W = gaussians_viz.resolution_slider.value
        H = int(gaussians_viz.resolution_slider.value / camera.aspect)
        w2c = get_w2c(camera)

        vfov_deg = camera.fov
        hfov_deg = vfov_to_hfov(vfov_deg, H, W)
        FoVx = np.deg2rad(hfov_deg)
        FoVy = np.deg2rad(vfov_deg)
        # Pinhole focal length from field of view: f = pixels / (2 * tan(fov / 2)).
        # Computed inline because `fov2focal` was never imported here.
        fx = W / (2 * np.tan(FoVx / 2))
        fy = H / (2 * np.tan(FoVy / 2))
        # Principal point at the image centre.
        cx = W // 2
        cy = H // 2
        T = torch.from_numpy(w2c)
        current_cam = Camera.init_from_gui(
            uid=-1,
            T=T,
            FoVx=FoVx,
            FoVy=FoVy,
            fx=fx,
            fy=fy,
            cx=cx,
            cy=cy,
            H=H,
            W=W,
        )
        current_cam.update_RT(T[0:3, 0:3], T[0:3, 3])
        return current_cam

    @torch.no_grad()
    def update():
        """Render one frame for every connected client and refresh the FPS readout."""
        # The old guard was inverted (`if need_update: return`), which left
        # the render loop below unreachable.
        if not gaussians_viz.need_update:
            return

        interval = None
        for client in gaussians_viz.server.get_clients().values():
            try:
                custom_cam = get_current_cam(client.camera)

                start_cuda = torch.cuda.Event(enable_timing=True)
                end_cuda = torch.cuda.Event(enable_timing=True)
                start_cuda.record()

                render_pkg = render(
                    custom_cam,
                    gaussians_viz.mapper.gaussians,
                    gaussians_viz.mapper.pipeline_params,
                    gaussians_viz.mapper.background,
                )

                # Time a single render on the GPU.
                end_cuda.record()
                torch.cuda.synchronize()
                interval = start_cuda.elapsed_time(end_cuda) / 1000.0

                # The renderer yields a channels-first (C, H, W) image; viser
                # expects (H, W, C). Clamp so out-of-range floats don't wrap
                # when encoded.
                out = (
                    render_pkg["render"]
                    .clamp(0.0, 1.0)
                    .permute(1, 2, 0)
                    .cpu()
                    .detach()
                    .numpy()
                    .astype(np.float32)
                )
            except RuntimeError as e:
                print(e)
                # Count a failed render as a 1-second frame for the FPS estimate.
                interval = 1
                continue
            client.set_background_image(out, format="jpeg")
            gaussians_viz.debug_idx += 1

        # No client was processed (e.g. none connected): nothing to time.
        if interval is None:
            return
        gaussians_viz.render_times.append(interval)
        gaussians_viz.fps.value = f"{1.0 / np.mean(gaussians_viz.render_times):.3g}"

    while True:
        update()
        time.sleep(1e-3)

五、framed.py

import torch
import torch.nn.functional as F
from torchvision.transforms import Compose
import numpy as np
import cv2

from .lightglue import LightGlue, SuperPoint, DISK
from .depth_anything_v2.dpt import DepthAnythingV2
from .depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet
from ultralytics import YOLO

def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
    """Convert a numpy image to a channels-first float tensor scaled by 1/255.

    HxWxC inputs become CxHxW; HxW inputs become 1xHxW.

    Raises:
        ValueError: if `image` is neither 2-D nor 3-D.
    """
    ndim = image.ndim
    if ndim not in (2, 3):
        raise ValueError(f"Not an image: {image.shape}")
    # Grayscale gains a leading channel axis; colour moves channels first.
    chw = image[None] if ndim == 2 else image.transpose((2, 0, 1))
    return torch.tensor(chw / 255.0, dtype=torch.float)

#将深度图归一化到 [0, 255] 范围
def depth_to_color(depth: torch.Tensor) -> np.ndarray:
    """Convert a depth map to a false-colour (INFERNO) image.

    The depth is min-max normalized to [0, 255] and then colour-mapped with
    OpenCV. A constant depth map (max == min) previously divided by zero,
    yielding NaNs whose uint8 cast is undefined. It now maps to all zeros
    (the lowest colour of the map).

    Args:
        depth: depth tensor on any device.

    Returns:
        uint8 colour image returned by `cv2.applyColorMap`.
    """
    d_min = depth.min()
    d_range = depth.max() - d_min
    if d_range > 0:
        depth = (depth - d_min) / d_range * 255.0
    else:
        depth = torch.zeros_like(depth, dtype=torch.float32)

    # Map the normalized depth through the INFERNO colour map.
    depth = cv2.applyColorMap(depth.cpu().numpy().astype(np.uint8), cv2.COLORMAP_INFERNO)
    return depth

class VisualFrontend:
    """Visual front end of the SLAM pipeline.

    For each multi-camera frame it extracts keypoints on every image. Those
    keypoints are matched against the same camera of the previous frame
    ("intra-frame") and across the calibrated stereo pairs ("inter-frame").
    Monocular depth estimation and segmentation run optionally. The results
    are then pushed into the sliding window.
    """

    def __init__(self, config, window, calib):
        """
        Initialize the VisualFrontend with configuration, window, and calibration data.

        Args:
            config (dict): Frontend configuration. Keys read by this class:
                'max_keypoints', 'keypoint_type', 'use_depth_estimation',
                'use_segmentation', 'publish_debug', 'only_plot_image'.
            window (object): Sliding window object; receives every processed
                frame via ``window.feed_frame``.
            calib (object): Calibration data; ``calib.stereo_pairs`` lists the
                camera pairs matched against each other.
        """
        self.config = config
        self.window = window
        self.calib = calib
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Per-camera results of the previous frame; None until the first frame arrives.
        self.pre_images = None
        # Depth-Anything preprocessing pipeline, built once on first use.
        self._depth_transform = None

        self.init_lightglue_model()

        if self.config['use_depth_estimation']:
            self.init_depthanything_model()

        if self.config['use_segmentation']:
            self.init_yolo_model()

    def __call__(self, frame, viz=None):
        """
        Process the current frame and update the sliding window.

        Args:
            frame (dict): Current frame data. ``frame['camera']['data'][i]`` is the
                image of camera ``i`` and ``frame['camera']['timestamp'][i]`` its
                timestamp.
            viz (object, optional): Visualization object. Defaults to None.
        """
        cur_images = []
        matches = {}

        # Intra-frame matching: camera i of the previous frame vs. camera i now.
        intra_matches = []
        camera_data = frame['camera']['data']
        for i in range(len(camera_data)):
            raw_image = camera_data[i]
            timestamp = frame['camera']['timestamp'][i]
            cur_feats = self.extract_features(numpy_image_to_torch(raw_image))

            if self.pre_images is not None:
                intra_matches.append(self.match_features(self.pre_images[i]['feats'], cur_feats))

            image_info = {'timestamp': timestamp, 'image': raw_image, 'feats': cur_feats}

            # Predict depth
            if self.config['use_depth_estimation']:
                image_info['depth_pred'] = self.single_depth_estimation(raw_image)

            # Predict segmentation mask
            if self.config['use_segmentation']:
                image_info['seg_mask'] = self.segment_image(raw_image)

            cur_images.append(image_info)
        matches["intra-frame"] = intra_matches

        # Inter-frame matching between the stereo pairs of the current frame.
        # The camera index is taken from the last character of the camera name.
        inter_matches = []
        for pair in self.calib.stereo_pairs:
            left_feats = cur_images[int(pair[0][-1])]['feats']
            right_feats = cur_images[int(pair[1][-1])]['feats']
            inter_matches.append(self.match_features(left_feats, right_feats))
        matches["inter-frame"] = inter_matches

        # Push the frame to the window
        self.window.feed_frame(frame, cur_images, matches)
        self.pre_images = cur_images

        if self.config['publish_debug']:
            self.publish(viz)

    def init_lightglue_model(self):
        """Initialize the SuperPoint extractor and LightGlue matcher for feature matching."""
        self.extractor = SuperPoint(max_num_keypoints=self.config['max_keypoints'])
        # map_location keeps GPU-saved checkpoints loadable when self.device fell back to CPU.
        self.extractor.load_state_dict(
            torch.load('./checkpoints/superpoint_v1.pth', map_location=self.device)
        )
        self.extractor.eval().to(self.device)
        self.matcher = LightGlue(features=self.config['keypoint_type']).eval().to(self.device)

    def init_depthanything_model(self):
        """Initialize the DepthAnythingV2 model (ViT-S encoder) for depth estimation."""
        model_configs = {
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
        }

        encoder = 'vits'  # Options: 'vits', 'vitb', 'vitl', 'vitg'
        self.depth_anything = DepthAnythingV2(**model_configs[encoder])
        self.depth_anything.load_state_dict(
            torch.load(f'./checkpoints/depth_anything_v2_{encoder}.pth', map_location=self.device)
        )
        self.depth_anything.to(self.device).eval()

    def init_yolo_model(self):
        """Initialize the YOLO models for segmentation and tracking."""
        self.seg_model = YOLO("./checkpoints/yolov8m-seg.pt")
        self.track_model = YOLO("./checkpoints/yolov8n.pt")

    def extract_features(self, image):
        """
        Extract features from the given image with the SuperPoint extractor.

        Args:
            image (torch.Tensor): Input image.

        Returns:
            dict: Extracted features.
        """
        with torch.no_grad():
            return self.extractor.extract(image.to(self.device))

    def match_features(self, feats0, feats1):
        """
        Match features between two sets of features with LightGlue.

        Args:
            feats0 (dict): Features from the first image.
            feats1 (dict): Features from the second image.

        Returns:
            dict: Matched features.
        """
        with torch.no_grad():
            return self.matcher({"image0": feats0, "image1": feats1})

    def track_features(self, image0, image1, keypoints0):
        """Track features between two images. Not implemented."""
        raise NotImplementedError

    def single_depth_estimation(self, image):
        """
        Perform single depth estimation on the given image.

        Args:
            image (np.ndarray): Input image, 3-channel BGR or 2-D grayscale
                (assumed 8-bit, since it is divided by 255).

        Returns:
            torch.Tensor: Estimated depth map resized back to the input (H, W).
        """
        transform = getattr(self, '_depth_transform', None)
        if transform is None:
            # The preprocessing pipeline is identical for every call; build it once.
            transform = Compose([
                Resize(
                    width=518,
                    height=518,
                    resize_target=False,
                    keep_aspect_ratio=True,
                    ensure_multiple_of=14,
                    resize_method='lower_bound',
                    image_interpolation_method=cv2.INTER_CUBIC,
                ),
                NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                PrepareForNet(),
            ])
            self._depth_transform = transform

        if image.ndim == 2:
            # COLOR_BGR2RGB rejects single-channel input; expand grayscale to
            # 3 channels the same way segment_image and publish do.
            image = np.repeat(image[:, :, None], 3, axis=2)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
        h, w = image.shape[:2]
        image = transform({'image': image})['image']
        image = torch.from_numpy(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            depth = self.depth_anything(image)

            # Bilinearly resize the prediction back to the original image size.
            depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]

        return depth

    def track_image(self, image):
        """Track features in the given image. Not implemented."""
        raise NotImplementedError

    def segment_image(self, image):
        """
        Perform image segmentation using the YOLO model.

        Args:
            image (np.ndarray): Input image (2-D grayscale is expanded to 3 channels).

        Returns:
            np.ndarray: Rendered segmentation result from ``result[0].plot()``.
        """
        if image.ndim == 2:
            # Grayscale input: replicate it into three channels for YOLO.
            image = np.repeat(image[:, :, None], 3, axis=2)
        result = self.seg_model(image, verbose=False)

        return result[0].plot()

    def publish(self, viz):
        """
        Publish the processed images for visualization.

        Each camera image is labelled with its name. When 'only_plot_image' is
        off, the depth and/or segmentation renderings are stacked below it. All
        cameras are then concatenated side by side.

        Args:
            viz (object): Visualization object; nothing is published when None.
        """
        if viz is None:
            return

        merge_image = None
        for i in range(len(self.pre_images)):
            cam_name = f"cam{i}"
            image = self.pre_images[i]['image'].copy()
            if image.ndim == 2:
                # Grayscale image: expand to 3 channels so it can be stacked with colour images.
                image = np.repeat(image[:, :, None], 3, axis=2)
            cv2.putText(image, cam_name, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            if not self.config['only_plot_image']:
                if self.config['use_depth_estimation']:
                    depth = depth_to_color(self.pre_images[i]['depth_pred'])
                    image = np.vstack((image, depth))

                if self.config['use_segmentation']:
                    seg_mask = self.pre_images[i]['seg_mask']
                    image = np.vstack((image, seg_mask))

            # Concatenate the per-camera columns horizontally into one image.
            if merge_image is None:
                merge_image = image
            else:
                merge_image = np.hstack((merge_image, image))

        # Timestamp of camera 0, scaled by 1e3 and truncated to int.
        timestamp = int(self.pre_images[0]['timestamp'] * 1e3)
        viz.show_image('Frontend', timestamp, merge_image, True)

六、frontend.py

import torch
import torch.nn.functional as F
from torchvision.transforms import Compose
import numpy as np
import cv2

from .lightglue import LightGlue, SuperPoint, DISK
from .depth_anything_v2.dpt import DepthAnythingV2
from .depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet
from ultralytics import YOLO

def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
    """Convert an HxW or HxWxC image array into a CxHxW float tensor scaled by 1/255.

    Args:
        image: 2-D (grayscale) or 3-D (channel-last) image array.

    Returns:
        A float32 tensor in channel-first layout; a grayscale input gets a
        leading channel axis of size 1.

    Raises:
        ValueError: if ``image`` is neither 2-D nor 3-D.
    """
    rank = image.ndim
    if rank == 2:
        chw = image[None]  # add a channel axis
    elif rank == 3:
        chw = image.transpose((2, 0, 1))  # HxWxC -> CxHxW
    else:
        raise ValueError(f"Not an image: {image.shape}")
    scaled = chw / 255.0
    return torch.tensor(scaled, dtype=torch.float)

def depth_to_color(depth: torch.Tensor) -> np.ndarray:
    """Convert a depth map to a color image.

    The depth values are min-max normalised to [0, 255] and rendered with
    OpenCV's INFERNO colormap.

    Args:
        depth: Depth map tensor (any device).

    Returns:
        The colour-mapped image returned by ``cv2.applyColorMap``.
    """
    d_min = depth.min()
    span = depth.max() - d_min
    if span > 0:
        depth = (depth - d_min) / span * 255.0
    else:
        # A constant map would divide 0/0 and produce NaNs, whose uint8 cast is
        # undefined; render it as the lowest colour instead.
        depth = torch.zeros_like(depth)
    depth = cv2.applyColorMap(depth.detach().cpu().numpy().astype(np.uint8), cv2.COLORMAP_INFERNO)
    return depth

class VisualFrontend:
    """Visual front end of the SLAM pipeline.

    For each multi-camera frame it extracts keypoints on every image. Those
    keypoints are matched against the same camera of the previous frame
    ("intra-frame") and across the calibrated stereo pairs ("inter-frame").
    Monocular depth estimation and segmentation run optionally. The results
    are then pushed into the sliding window.
    """

    def __init__(self, config, window, calib):
        """
        Initialize the VisualFrontend with configuration, window, and calibration data.

        Args:
            config (dict): Frontend configuration. Keys read by this class:
                'max_keypoints', 'keypoint_type', 'use_depth_estimation',
                'use_segmentation', 'publish_debug', 'only_plot_image'.
            window (object): Sliding window object; receives every processed
                frame via ``window.feed_frame``.
            calib (object): Calibration data; ``calib.stereo_pairs`` lists the
                camera pairs matched against each other.
        """
        self.config = config
        self.window = window
        self.calib = calib
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Per-camera results of the previous frame; None until the first frame arrives.
        self.pre_images = None
        # Depth-Anything preprocessing pipeline, built once on first use.
        self._depth_transform = None

        self.init_lightglue_model()

        if self.config['use_depth_estimation']:
            self.init_depthanything_model()

        if self.config['use_segmentation']:
            self.init_yolo_model()

    def __call__(self, frame, viz=None):
        """
        Process the current frame and update the sliding window.

        Args:
            frame (dict): Current frame data. ``frame['camera']['data'][i]`` is the
                image of camera ``i`` and ``frame['camera']['timestamp'][i]`` its
                timestamp.
            viz (object, optional): Visualization object. Defaults to None.
        """
        cur_images = []
        matches = {}

        # Intra-frame matching: camera i of the previous frame vs. camera i now.
        intra_matches = []
        camera_data = frame['camera']['data']
        for i in range(len(camera_data)):
            raw_image = camera_data[i]
            timestamp = frame['camera']['timestamp'][i]
            cur_feats = self.extract_features(numpy_image_to_torch(raw_image))

            if self.pre_images is not None:
                intra_matches.append(self.match_features(self.pre_images[i]['feats'], cur_feats))

            image_info = {'timestamp': timestamp, 'image': raw_image, 'feats': cur_feats}

            # Predict depth
            if self.config['use_depth_estimation']:
                image_info['depth_pred'] = self.single_depth_estimation(raw_image)

            # Predict segmentation mask
            if self.config['use_segmentation']:
                image_info['seg_mask'] = self.segment_image(raw_image)

            cur_images.append(image_info)
        matches["intra-frame"] = intra_matches

        # Inter-frame matching between the stereo pairs of the current frame.
        # The camera index is taken from the last character of the camera name.
        inter_matches = []
        for pair in self.calib.stereo_pairs:
            left_feats = cur_images[int(pair[0][-1])]['feats']
            right_feats = cur_images[int(pair[1][-1])]['feats']
            inter_matches.append(self.match_features(left_feats, right_feats))
        matches["inter-frame"] = inter_matches

        # Push the frame to the window
        self.window.feed_frame(frame, cur_images, matches)
        self.pre_images = cur_images

        if self.config['publish_debug']:
            self.publish(viz)

    def init_lightglue_model(self):
        """Initialize the SuperPoint extractor and LightGlue matcher for feature matching."""
        self.extractor = SuperPoint(max_num_keypoints=self.config['max_keypoints'])
        # map_location keeps GPU-saved checkpoints loadable when self.device fell back to CPU.
        self.extractor.load_state_dict(
            torch.load('./checkpoints/superpoint_v1.pth', map_location=self.device)
        )
        self.extractor.eval().to(self.device)
        self.matcher = LightGlue(features=self.config['keypoint_type']).eval().to(self.device)

    def init_depthanything_model(self):
        """Initialize the DepthAnythingV2 model (ViT-S encoder) for depth estimation."""
        model_configs = {
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
        }

        encoder = 'vits'  # Options: 'vits', 'vitb', 'vitl', 'vitg'
        self.depth_anything = DepthAnythingV2(**model_configs[encoder])
        self.depth_anything.load_state_dict(
            torch.load(f'./checkpoints/depth_anything_v2_{encoder}.pth', map_location=self.device)
        )
        self.depth_anything.to(self.device).eval()

    def init_yolo_model(self):
        """Initialize the YOLO models for segmentation and tracking."""
        self.seg_model = YOLO("./checkpoints/yolov8m-seg.pt")
        self.track_model = YOLO("./checkpoints/yolov8n.pt")

    def extract_features(self, image):
        """
        Extract features from the given image with the SuperPoint extractor.

        Args:
            image (torch.Tensor): Input image.

        Returns:
            dict: Extracted features.
        """
        with torch.no_grad():
            return self.extractor.extract(image.to(self.device))

    def match_features(self, feats0, feats1):
        """
        Match features between two sets of features with LightGlue.

        Args:
            feats0 (dict): Features from the first image.
            feats1 (dict): Features from the second image.

        Returns:
            dict: Matched features.
        """
        with torch.no_grad():
            return self.matcher({"image0": feats0, "image1": feats1})

    def track_features(self, image0, image1, keypoints0):
        """Track features between two images. Not implemented."""
        raise NotImplementedError

    def single_depth_estimation(self, image):
        """
        Perform single depth estimation on the given image.

        Args:
            image (np.ndarray): Input image, 3-channel BGR or 2-D grayscale
                (assumed 8-bit, since it is divided by 255).

        Returns:
            torch.Tensor: Estimated depth map resized back to the input (H, W).
        """
        transform = getattr(self, '_depth_transform', None)
        if transform is None:
            # The preprocessing pipeline is identical for every call; build it once.
            transform = Compose([
                Resize(
                    width=518,
                    height=518,
                    resize_target=False,
                    keep_aspect_ratio=True,
                    ensure_multiple_of=14,
                    resize_method='lower_bound',
                    image_interpolation_method=cv2.INTER_CUBIC,
                ),
                NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                PrepareForNet(),
            ])
            self._depth_transform = transform

        if image.ndim == 2:
            # COLOR_BGR2RGB rejects single-channel input; expand grayscale to
            # 3 channels the same way segment_image and publish do.
            image = np.repeat(image[:, :, None], 3, axis=2)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
        h, w = image.shape[:2]
        image = transform({'image': image})['image']
        image = torch.from_numpy(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            depth = self.depth_anything(image)
            # Bilinearly resize the prediction back to the original image size.
            depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]

        return depth

    def track_image(self, image):
        """Track features in the given image. Not implemented."""
        raise NotImplementedError

    def segment_image(self, image):
        """
        Perform image segmentation using the YOLO model.

        Args:
            image (np.ndarray): Input image (2-D grayscale is expanded to 3 channels).

        Returns:
            np.ndarray: Rendered segmentation result from ``result[0].plot()``.
        """
        if image.ndim == 2:
            # Grayscale input: replicate it into three channels for YOLO.
            image = np.repeat(image[:, :, None], 3, axis=2)
        result = self.seg_model(image, verbose=False)

        return result[0].plot()

    def publish(self, viz):
        """
        Publish the processed images for visualization.

        Each camera image is labelled with its name. When 'only_plot_image' is
        off, the depth and/or segmentation renderings are stacked below it. All
        cameras are then concatenated side by side.

        Args:
            viz (object): Visualization object; nothing is published when None.
        """
        if viz is None:
            return

        merge_image = None
        for i in range(len(self.pre_images)):
            cam_name = f"cam{i}"
            image = self.pre_images[i]['image'].copy()
            if image.ndim == 2:
                # Grayscale image: expand to 3 channels so it can be stacked with colour images.
                image = np.repeat(image[:, :, None], 3, axis=2)
            cv2.putText(image, cam_name, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            if not self.config['only_plot_image']:
                if self.config['use_depth_estimation']:
                    depth = depth_to_color(self.pre_images[i]['depth_pred'])
                    image = np.vstack((image, depth))

                if self.config['use_segmentation']:
                    seg_mask = self.pre_images[i]['seg_mask']
                    image = np.vstack((image, seg_mask))

            # Concatenate the per-camera columns horizontally into one image.
            if merge_image is None:
                merge_image = image
            else:
                merge_image = np.hstack((merge_image, image))

        # Timestamp of camera 0, scaled by 1e3 and truncated to int.
        timestamp = int(self.pre_images[0]['timestamp'] * 1e3)
        # Show the merged image on the 'Frontend' topic.
        viz.show_image('Frontend', timestamp, merge_image, True)

七、slam.py

from .utils.logging_utils import Log
from core.frontend import VisualFrontend
from core.droid import load_droid_model, MotionFilter
from core.window import SlideWindow
from core.backend import Backend

from pyfoxglove import Visualizer

import torch.multiprocessing as mp

class SLAM:
    """Top-level pipeline: motion filter -> visual frontend -> sliding window -> backend."""

    def __init__(self, config, calib):
        """Build every pipeline stage from its section of ``config``.

        Args:
            config (dict): Full configuration; the 'SlideWindow', 'MotionFilter',
                'Frontend', 'Backend' and 'Visualizer' sections are read.
            calib: Calibration data shared by the window, frontend and backend.
        """
        self.config, self.calib = config, calib

        # Pretrained DROID network, handed to the motion filter.
        self.droid_net = load_droid_model('./checkpoints/droid.pth')

        cfg = self.config
        # Sliding window that keeps the recent keyframes.
        self.window = SlideWindow(cfg['SlideWindow'], self.calib)
        # Decides which incoming frames are passed on to the rest of the pipeline.
        self.motion_filter = MotionFilter(cfg['MotionFilter'], self.window, self.droid_net)
        # Frontend: feature extraction and matching on accepted frames.
        self.frontend = VisualFrontend(cfg['Frontend'], self.window, self.calib)
        # Backend: started with the window and calibration.
        self.backend = Backend(cfg["Backend"], self.window, self.calib)

        self.init_viz(config['Visualizer'])

    def finalize(self):
        """Shut the pipeline down by finalizing the backend."""
        self.backend.finalize()

    def feed_frame(self, frame):
        """Run one frame through the pipeline if the motion filter accepts it."""
        accepted = self.motion_filter(frame, self.viz)
        if accepted:
            self.frontend(frame, self.viz)
            self.window(self.viz)
            self.backend()

    def init_viz(self, config):
        """Create the Visualizer when ``config['enable']`` is truthy, else set ``self.viz`` to None."""
        self.viz = Visualizer(config['port_num'], config['sleep_time']) if config['enable'] else None

八、window.py

from core.frame import KeyFrame

import numpy as np
import torch
import gtsam
from pyslamcpts import SparseMap

from .utils.type_utils import numpy_pose_to_torch, torch_pose_to_numpy, gtsam_pose_to_numpy

class WindowInfoTable:
    """Preallocated per-slot state for the sliding window.

    The buffers poses, disps, fmaps, nets and inps have ``window_size + 1``
    slots along dim 0, and ``roll``/``replace`` move them together. The
    per-camera intrinsics and extrinsics are stored separately and never rolled.
    """

    # Every buffer lives on this one explicit device, so tensors of the table can
    # be combined without cross-device errors.
    DEVICE = "cuda:0"

    # Per-slot buffers that are shifted/copied together.
    _SLOT_BUFFERS = ("poses", "disps", "fmaps", "nets", "inps")

    def __init__(self, window_size, calib):
        """
        Initialize the WindowInfoTable with the given window size and calibration data.

        Args:
            window_size (int): The size of the sliding window.
            calib: Calibration data; ``calib.cam_size`` and ``calib.sensors["CAMERA"]`` are read.
        """
        self.calib = calib
        self.window_size = window_size
        self._initialize_parameters()
        self._initialize_calibration_data()

    def _initialize_parameters(self):
        """
        Allocate the pose, disparity, feature-map, network and input buffers plus
        the per-camera intrinsics/extrinsics.

        Spatial buffers use 1/8 of the camera resolution (``ht // 8``, ``wd // 8``).
        """
        wd, ht = self.calib.cam_size
        cam_num = len(self.calib.sensors["CAMERA"])
        slots = self.window_size + 1
        dev = self.DEVICE
        feat_shape = (slots, cam_num, 128, ht // 8, wd // 8)

        self.poses = torch.zeros(slots, 7, dtype=torch.float, device=dev)  # Tbw
        self.disps = torch.ones(slots, cam_num, ht // 8, wd // 8, dtype=torch.float, device=dev)
        self.fmaps = torch.zeros(feat_shape, dtype=torch.half, device=dev)
        self.nets = torch.zeros(feat_shape, dtype=torch.half, device=dev)
        self.inps = torch.zeros(feat_shape, dtype=torch.half, device=dev)
        self.intrinsics = torch.zeros(cam_num, 4, dtype=torch.float, device=dev)  # camera intrinsics
        self.extrinsics = torch.zeros(cam_num, 7, dtype=torch.float, device=dev)  # camera extrinsics

    def _initialize_calibration_data(self):
        """
        Fill in extrinsics and intrinsics for each camera.

        The camera index is the last character of the camera name (e.g. "cam1" -> 1).
        """
        for cam_name, cam_info in self.calib.sensors["CAMERA"].items():
            cam_idx = int(cam_name[-1])
            self._set_extrinsics(cam_idx, cam_info.extrinsic)
            self._set_intrinsics(cam_idx, cam_info.opt_camera_params)

    def _set_extrinsics(self, cam_idx, extrinsic):
        """
        Set the extrinsics for a given camera index.

        Args:
            cam_idx (int): The index of the camera.
            extrinsic: The extrinsic parameters of the camera, converted with
                ``numpy_pose_to_torch`` into a 7-element row.
        """
        self.extrinsics[cam_idx] = numpy_pose_to_torch(extrinsic).to(self.DEVICE)

    def _set_intrinsics(self, cam_idx, opt_camera_params):
        """
        Set the intrinsics for a given camera index.

        fx, fy, cx and cy are divided by 8 to match the 1/8-resolution buffers.

        Args:
            cam_idx (int): The index of the camera.
            opt_camera_params: Mapping with "fx", "fy", "cx", "cy" entries.
        """
        self.intrinsics[cam_idx] = torch.tensor([
            opt_camera_params["fx"] / 8,
            opt_camera_params["fy"] / 8,
            opt_camera_params["cx"] / 8,
            opt_camera_params["cy"] / 8
        ], dtype=torch.float, device=self.DEVICE)

    def roll(self, shift=-1):
        """
        Roll the per-slot buffers by a specified shift.

        Args:
            shift (int): Number of slots to roll along dim 0. Default is -1.
        """
        self._roll_parameters(shift)

    def _roll_parameters(self, shift):
        """
        Cyclically roll poses, disps, fmaps, nets and inps along dim 0 by ``shift``.

        With the default -1, slot 0 wraps around to the last slot.

        Args:
            shift (int): The number of positions to roll the parameters.
        """
        for name in self._SLOT_BUFFERS:
            setattr(self, name, torch.roll(getattr(self, name), shifts=shift, dims=0))

    def replace(self, idx, shift=1):
        """
        Copy the state stored at slot ``idx`` into slot ``idx - shift``.

        Args:
            idx (int): The source slot.
            shift (int): Distance to the destination slot ``idx - shift``. Default is 1.
        """
        self._replace_parameters(idx, shift)

    def _replace_parameters(self, idx, shift):
        """
        Overwrite slot ``idx - shift`` of every per-slot buffer with slot ``idx``.

        The copy is done in place; slot ``idx`` itself is left unchanged.

        Args:
            idx (int): The source slot.
            shift (int): Distance to the destination slot ``idx - shift``.
        """
        for name in self._SLOT_BUFFERS:
            buf = getattr(self, name)
            buf[idx - shift] = buf[idx]


class SlideWindow:
    """Sliding window of keyframes plus the sparse map built from their matches."""

    def __init__(self, config, calib):
        """Initialize the sliding window with configuration and calibration data.

        Args:
            config (dict): Window configuration. 'window_size' is read here;
                'publish_debug', 'publish_matches', 'publish_pred_pcd' and
                'publish_sparse_pcd' are read when publishing.
            calib: Calibration data with per-camera sensor information.
        """
        self.config = config
        self.calib = calib
        self.window_size = self.config['window_size']
        self.keyframes = []
        self.window_info_table = WindowInfoTable(self.window_size, self.calib)
        self.is_initialized = False
        self.next_kf_id = 0
        self._init_sparse_map()

    def _init_sparse_map(self):
        """Create a fresh sparse map and give it every camera's extrinsics."""
        self.sparse_map = SparseMap()
        calibs = [cam_info.extrinsic for cam_info in self.calib.sensors["CAMERA"].values()]
        self.sparse_map.set_calibration(calibs)

    def feed_frame(self, raw_info, images_info, match_info):
        """Feed a new frame into the sliding window.

        The frame becomes a new keyframe, is registered in the sparse map, and
        has its network features stored in the window info table.

        ``save_temp_droid`` must have been called before this, because
        ``_update_window_info_table`` reads ``self.net_info``.
        """
        new_kf = self._create_keyframe(raw_info, images_info, match_info)
        self._update_sparse_map(new_kf)
        self.keyframes.append(new_kf)
        self._update_window_info_table()
        self.next_kf_id += 1

    def _create_keyframe(self, raw_info, images_info, match_info):
        """Create a new keyframe with the next keyframe id and load the provided information."""
        new_kf = KeyFrame(self.calib, self.next_kf_id)
        new_kf.load(raw_info, images_info, match_info)
        return new_kf

    def _update_sparse_map(self, new_kf):
        """Add the keyframe and its intra-/inter-frame matches to the sparse map."""
        kf_id, imgs, keypoints, bearings = new_kf.frame_info()
        self.sparse_map.add_keyframe(kf_id, imgs, keypoints, bearings)

        # Matches against the previous keyframe (kf_id is None when there are none).
        pre_kf_id, kf_id, matches = new_kf.intra_matches_info()
        if kf_id is not None:
            self.sparse_map.add_intra_matches(pre_kf_id, kf_id, matches)

        # Stereo matches within this keyframe.
        kf_id, stereo_pairs, matches = new_kf.inter_matches_info()
        if kf_id is not None:
            self.sparse_map.add_inter_matches(kf_id, stereo_pairs, matches)

    def _update_window_info_table(self):
        """Store the latest saved network info (fmaps, nets, inps) in the newest keyframe's slot."""
        idx = self.get_frame_size() - 1
        self.window_info_table.fmaps[idx] = self.net_info[0]
        self.window_info_table.nets[idx] = self.net_info[1]
        self.window_info_table.inps[idx] = self.net_info[2]

    def __call__(self, viz=None):
        """Advance the window by one step.

        Before initialization the window waits until it can be initialized from
        ground truth; afterwards the newest keyframe's state is set up. The
        window is then optionally published and the oldest frame dropped if it
        is full.
        """
        if not self.is_initialized:
            self.init_window_by_gt()

        else:
            self.init_new_state()

        if self.config.get('publish_debug'):
            self.publish(viz)

        self.slide_frame()

    def init_window_by_gt(self):
        """Initialize the sliding window with ground-truth data.

        This only happens once more than ``window_size`` keyframes are present.
        Every keyframe gets a GT pose and IMU preintegration against its
        predecessor, and the sparse map is then re-triangulated.
        """
        if self.get_frame_size() > self.window_size:
            self.is_initialized = True
            for i, kf in enumerate(self.keyframes):
                if i != 0:
                    kf.make_preintegration(self.keyframes[i-1])
                self._update_window_info_table_pose(i, kf)
                self.sparse_map.update_keyframe_pose(kf.kf_id, gtsam_pose_to_numpy(kf.Twb))
            self.sparse_map.triangulate()

    def _update_window_info_table_pose(self, i, kf):
        """Write keyframe ``kf``'s ground-truth state into slot ``i`` and onto the keyframe.

        The table stores Tbw, the inverse of the GT Twb. The keyframe gets Twb,
        velocity and IMU bias from ``kf.gt_info``.
        """
        Twb = kf.gt_info['pose']
        Tbw = np.linalg.inv(Twb)
        self.window_info_table.poses[i] = numpy_pose_to_torch(Tbw).to("cuda:0")
        kf.Twb = gtsam.Pose3(gtsam.Rot3(Twb[:3, :3]), gtsam.Point3(Twb[:3, 3]))
        kf.vwb = kf.gt_info['velocity']
        kf.imu_bias = gtsam.imuBias.ConstantBias(np.array(kf.gt_info['bias_accel']), np.array(kf.gt_info['bias_gyro']))

    def slide_frame(self):
        """Drop the oldest keyframe (from the list, the sparse map and the table) when the window is over-full."""
        if self.get_frame_size() > self.window_size:
            self.sparse_map.remove_keyframe(self.keyframes[0].kf_id)
            self.keyframes.pop(0)
            self.window_info_table.roll(-1)

    def init_new_state(self):
        """Initialize the new state for the current (newest) keyframe.

        NOTE: like ``init_window_by_gt``, this seeds the keyframe's pose,
        velocity and bias from ground truth (via ``_update_window_info_table_pose``).
        """
        cur_kf = self.keyframes[-1]
        cur_kf.make_preintegration(self.keyframes[-2])
        self._update_window_info_table_pose(self.get_frame_size() - 1, cur_kf)
        self.sparse_map.update_keyframe_pose(cur_kf.kf_id, gtsam_pose_to_numpy(cur_kf.Twb))
        self.sparse_map.triangulate()

    def get_frame(self, idx):
        """Get a specific keyframe by index."""
        return self.keyframes[idx]

    def get_all_frames(self):
        """Get all keyframes in the sliding window."""
        return self.keyframes

    def get_window_size(self):
        """Get the window size."""
        return self.window_size

    def get_frame_size(self):
        """Get the current number of frames in the sliding window."""
        return len(self.keyframes)

    def clear(self):
        """Clear all keyframes from the sliding window.

        The sparse map is rebuilt as well. Keyframe ids restart at 0 after a
        clear, so keeping the old map would leave stale keyframes whose ids can
        collide with the new ones.
        """
        self.keyframes = []
        self.next_kf_id = 0
        self.is_initialized = False
        self._init_sparse_map()

    def save_temp_droid(self, gmap, net, inp):
        """Save the network outputs that ``feed_frame`` will store for the next keyframe."""
        self.net_info = [gmap, net, inp]

    def publish(self, viz):
        """Publish window information.

        Publishes camera extrinsics and ground truth always. Once initialized,
        it also publishes poses and, if configured, the sparse-map images.
        """
        if viz is None or self.get_frame_size() == 0:
            return

        # Publish camera extrinsics
        for cam_name, cam_info in self.calib.sensors["CAMERA"].items():
            extrinsics = cam_info.extrinsic
            viz.show_pose(f'ex_{cam_name}', 0, extrinsics, 'body', cam_name)

        # Publish the sparse map
        if self.is_initialized and self.config['publish_matches']:
            self._publish_sparse_map(viz)

        # Publish the poses
        if self.is_initialized:
            self._publish_poses(viz)

        # Publish the ground truth info
        self._publish_gt_info(viz)

    def _publish_sparse_map(self, viz):
        """Publish the per-camera keypoint-flow image and the stereo-keypoint image of the newest keyframe."""
        kpts_img = [self.sparse_map.draw_flow(self.keyframes[-1].kf_id, cam_id)
                    for cam_id in range(len(self.keyframes[-1].camera))]

        merge_image = np.concatenate(kpts_img, axis=1)
        timestamp = int(self.keyframes[-1].timestamp * 1e3)
        viz.show_image('Keypoints', timestamp, merge_image, True)

        merge_image = self.sparse_map.draw_stereo_keypoint(self.keyframes[-1].kf_id)
        viz.show_image('stereo_kpts', timestamp, merge_image, True)

    def _publish_poses(self, viz):
        """Publish the newest body pose, the keyframe path and, if configured, point clouds."""
        poses = [gtsam_pose_to_numpy(frame.Twb) for frame in self.keyframes]
        timestamp = int(self.keyframes[-1].timestamp * 1e3)
        new_pose = gtsam_pose_to_numpy(self.keyframes[-1].Twb)

        viz.show_pose('pose_body', timestamp, new_pose, 'world', 'body')
        viz.show_path('path_body', timestamp, poses, 'world')

        if self.config['publish_pred_pcd']:
            pts, clr = self.keyframes[-1].get_pcd_from_pred_depth(0)
            if pts is not None:
                viz.show_pointcloud("cam0_pcd", timestamp, pts, clr, "cam0", 1)

        if self.config['publish_sparse_pcd']:
            pts = self.sparse_map.get_world_points()
            if pts is not None:
                viz.show_pointcloud("sparse_world_pcd", timestamp, pts, [], "world", 1)

    def _publish_gt_info(self, viz):
        """Publish the ground-truth pose of the newest keyframe and the GT path of the window."""
        gt_poses = [frame.gt_info['pose'] for frame in self.keyframes]
        timestamp = int(self.keyframes[-1].gt_info['timestamp'] * 1e3)
        new_pose = self.keyframes[-1].gt_info['pose']

        viz.show_pose('pose_body_gt', timestamp, new_pose, 'world', 'body_gt')
        viz.show_path('path_body_gt', timestamp, gt_poses, 'world')
  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值