sqn代码理解

学习目标:

  • 理解sqn代码如何跑起来
  • 对主要代码进行学习理解

学习内容:

SQN: Weakly-Supervised Semantic Segmentation of Large-Scale 3D Point Clouds (ECCV2022)

(1) Setup

This code has been tested with Python 3.5, Tensorflow 1.11, CUDA 9.0 and cuDNN 7.4.1 on Ubuntu 16.04/Ubuntu 18.04.
这里说明了我们的一个实验代码运行所需要具备的条件,包括python、cuda的版本,以及在Ubuntu的下载版本。版本建议严格按照所给出的来使用,以防后期在跑实验的过程中出现其他因版本而出现的问题。

git clone --depth=1 https://github.com/QingyongHu/SQN && cd SQN
这里是克隆代码,创建SQN,进入到SQN目录上。
conda create -n sqn python=3.5
创建一个conda环境命名为sqn,用python 3.5的版本去创建。
source activate sqn
激活sqn环境
pip install -r helper_requirements.txt
从helper_requirements.txt安装环境需要的各种包和库。
sh compile_op.sh
执行.sh文件。如果遇到无法执行的情况需要打开文件逐条执行。

(2) Training (Semantic3D as example)

Start training with weakly supervised setting:

python main_Semantic3D.py --mode train --gpu 0 --labeled_point 0.1%

运行.py文件,训练模式,使用gpu 0,用0.1%标注的点。
这里我想要集中去理解下main_Semantic3D.py(因为后续可能会需要进行修改创新),放入部分主要代码。


from SQN import Network
from tester_Semantic3D import ModelTester
from tool import Plot
from tool import DataProcessing as DP
import tensorflow as tf
import numpy as np
导入了SQN模块中的Network类,用于神经网络的构建和训练。
导入了tester_Semantic3D模块中的ModelTester类,用于模型测试。
导入了tool模块中的Plot和DataProcessing类,用于数据处理和绘图。
导入了tensorflow和numpy库,用于深度学习和数值计算。

class Semantic3D:
    def __init__(self, labeled_point, gen_pseudo, retrain):
    类的初始化方法__init__设置了数据集路径、类别标签、训练/验证/测试文件列表等。
    
    def load_sub_sampled_clouds(self, sub_grid_size, labeled_point, retrain): 
    load_sub_sampled_clouds方法用于加载子采样的点云数据,并根据给定的标注点百分比进行随机标注。
    
    def get_batch_gen(self, split):
     get_batch_gen方法生成用于训练、验证和测试的数据批次。
     
    def get_tf_mapping(self):
        def tf_map(batch_xyz, batch_features, batch_labels, batch_pc_idx, batch_cloud_idx, batch_xyz_anno,batch_label_anno):
     get_tf_mapping方法定义了TensorFlow操作,用于数据增强和处理。
     
    def tf_augment_input(inputs):
    tf_augment_input静态方法用于对输入数据进行增强,旋转、缩放和添加噪声。
    
    def init_input_pipeline(self):
    init_input_pipeline方法初始化TensorFlow数据管道,用于批量处理和数据增强。

在主函数中, 使用argparse库解析命令行参数,如GPU编号、模式(训练、测试或可视化)、标注点百分比等。设置环境变量,指定CUDA设备和日志级别。创建Semantic3D类的实例,并初始化输入数据管道。 根据模式(训练、测试或可视化)执行相应的操作。

 if Mode == 'train':
        model = Network(dataset, cfg, FLAGS.retrain)
        model.train(dataset)
    elif Mode == 'test':
        cfg.saving = False
        model = Network(dataset, cfg)
如果模式为train,则创建Network类的实例并开始训练。
如果模式为test,则加载模型的快照,并使用ModelTester类进行评估。

(3)Evaluation:

python main_Semantic3D.py --mode test --gpu 0 --labeled_point 0.1%

  def evaluate(self, model, dataset, gen_pseudo=None, num_votes=100):

        # Smoothing parameter for votes
        test_smooth = 0.98

        # Initialise iterator with train data
        self.sess.run(dataset.test_init_op)

        if gen_pseudo:
            # Number of points per class in validation set
            val_proportions = np.zeros(model.config.num_classes, dtype=np.float32)
            i = 0
            for label_val in dataset.label_values:
                if label_val not in dataset.ignored_labels:
                    val_proportions[i] = np.sum(
                        [np.sum(labels == label_val) for labels in dataset.input_labels['test']])
                    i += 1

        # Test saving path
        saving_path = time.strftime('results/Log_%Y-%m-%d_%H-%M-%S', time.gmtime())
        test_path = join('test', saving_path.split('/')[-1])
        makedirs(test_path) if not exists(test_path) else None
        makedirs(join(test_path, 'predictions')) if not exists(join(test_path, 'predictions')) else None
        makedirs(join(test_path, 'probs')) if not exists(join(test_path, 'probs')) else None
   

evaluate 方法是评估模型的主要函数,它执行以下操作:

初始化 TensorFlow 数据迭代器。如果 gen_pseudo 为真,计算验证集上每个类别的点数比例。
创建保存预测结果的目录结构。 循环执行模型预测,直到达到指定的投票次数 num_votes。
如果发生 OutOfRangeError(表示迭代结束),则保存预测结果并计算混淆矩阵。
如果 gen_pseudo 为真,生成伪标签并保存到文件。重新初始化数据迭代器并继续下一个 epoch。

在这里插入图片描述

  • 5
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值