【代码阅读】VoteNet (推荐阅读源码,学到的东西很多)

本文深入解析VoteNet的源码,探讨3D点云处理中的目标检测技术。从代码结构、数据处理到训练过程详解,特别是Loss计算和NMS实现,提供对点云检测模型的深入理解。
摘要由CSDN通过智能技术生成

本文是针对VoteNet:Deep Hough Voting for 3D Object Detection in Point Clouds论文的源码的理解。VoteNet的解读可以参考我的另外一篇博客。具体的源码,可在github上下载

总的来说,代码写的非常优雅,我觉得光从代码的结构来看,也有很多可以借鉴的地方。所以本文先看一下代码的结构,然后再跟进去详解。

代码结构

train.py

# train.py
# 具体的内容可以看源码,这里只是记录一些代码要干的事情,很有借鉴意义的代码。

通过parser定义trian过程需要的参数
定义Log_dir和Dump_dir,并打开Log_dir,写入本次训练的parser中有关config的参数
定义一个写入log的函数,只是在trian.py中调用
加载dataset的config,加载dataset,在定义dataset时,使用train和eval两种模式,甚至还可以加入test模式,从而可以共用dataset的接口
定义worker_init_fn,加载dataloader
加载model,并将其放到nn.DataParallel中
加载criterion,由于loss计算越来越复杂,定义一个函数或者类
定义optimizer

def train_one_epoch():
	stat_dict = {
   }  # 定义一个储存中间过程变量的dict
    adjust_learning_rate(optimizer, EPOCH_CNT)  # 调整lr
    bnm_scheduler.step()  # decay BN momentum
    for batch_idx, batch_data_label in enumerate(TRAIN_DATALOADER):
    	前向计算
    	计算loss
    	反向传播
    	统计中间结果
    	展示中间结果

def eval_one_epoch():
	相比于train_one_epoch,增加计算最终统计量的部分例如AP,其他相同


def trian():
    for epoch in range(start_epoch, MAX_EPOCH):
    	log记录epoch的属性
    	np.random.seed()
        train_one_epoch()
        if EPOCH_CNT == 0 or EPOCH_CNT % 10 == 9: # Eval every 10 epochs
            loss = evaluate_one_epoch()
        # Save checkpoint
        save_dict = {
   'epoch': epoch+1, # after training one epoch, the start_epoch should be epoch+1
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                    }
        try: # with nn.DataParallel() the net is added as a submodule of DataParallel
            save_dict['model_state_dict'] = net.module.state_dict()
        except:
            save_dict['model_state_dict'] = net.state_dict()
        torch.save(save_dict, os.path.join(LOG_DIR, 'checkpoint.tar'))

if __name__=='__main__':
    train(start_epoch)

值得学习的点:

  • log文件的写入写成一个函数,只在train.py中写入,写入的地方有了限制,好查找
  • 使用stat_dict作为储存中间变量的字典,贯穿整个过程,使得在trian.py中可以找到需要的所有中间过程,而且留出了中间过程与train.py之间交互的接口,不需要修改太多就可以加入新功能

训练数据处理

Sunrgbd的data是以matlab形式储存的,作者提供了从matlab中读出数据和label的函数:

  • extract_split.m:将数据集分割成训练集和验证集
  • extract_rgbd_data_v2.m:将v2版的label以txt形式储存,并且复制每个数据的depth,img和calib文件
  • extract_rgbd_data_v1.m:讲v1版的label以txt形式储存

在储存好上述数据之后,使用python sunrgbd_data.py --gen_v1_data进一步处理数据,将depth数据降采样,并构造votes的数据。

sunrgbd_data.py

def extract_sunrgbd_data(idx_filename, split, output_folder, num_point=20000,
    type_whitelist=DEFAULT_TYPE_WHITELIST,
    save_votes=False, use_v1=False, skip_empty_scene=True):
    """ Extract scene point clouds and 
    bounding boxes (centroids, box sizes, heading angles, semantic classes).
    Dumped point clouds and boxes are in upright depth coord.

    Args:
        idx_filename: a TXT file where each line is an int number (index)
        split: training or testing
        save_votes: whether to compute and save Ground truth votes.
        use_v1: use the SUN RGB-D V1 data
        skip_empty_scene: if True, skip scenes that contain no object (no objet in whitelist)

    Dumps:
        <id>_pc.npz of (N,6) where N is for number of subsampled points and 6 is
            for XYZ and RGB (in 0~1) in upright depth coord
        <id>_bbox.npy of (K,8) where K is the number of objects, 8 is for
            centroids (cx,cy,cz), dimension (l,w,h), heanding_angle and semantic_class
        <id>_votes.npz of (N,10) with 0/1 indicating whether the point belongs to an object,
            then three sets of GT votes for up to three objects. If the point is only in one
            object's OBB, then the three GT votes are the same.
    """
    dataset = sunrgbd_object('./sunrgbd_trainval', split, use_v1=use_v1)
    data_idx_list = [int(line.rstrip()) for line in open(idx_filename)]

    if not os.path.exists(output_folder):
        os.mkdir(output_folder)

    for data_idx in data_idx_list:
        print('------------- ', data_idx)
        objects = dataset.get_label_objects(data_idx)

        # Skip scenes with 0 object
        if skip_empty_scene and (len(objects)==0 or \
            len([obj for obj in objects if obj.classname in type_whitelist])==0):
                continue

        object_list = []
        for obj in objects:
            if obj.classname not in type_whitelist: continue
            obb = np.zeros((8))
            obb[0:3] = obj.centroid
            # Note that compared with that in data_viz, we do not time 2 to l,w.h
            # neither do we flip the heading angle
            obb[3:6] = np.array([obj.l,obj.w,obj.h])
            obb[6] = obj.heading_angle
            obb[7] = sunrgbd_utils.type2class[obj.classname]
            object_list.append(obb)
        if len(object_list)==0:
            obbs = np.zeros((0,8))
        else:
            obbs = np.vstack(object_list) # (K,8)

        pc_upright_depth = dataset.get_depth(data_idx)
        pc_upright_depth_subsampled = pc_util.random_sampling(pc_upright_depth, num_point)
		
		# 将降采样到50000个点写入_pc.npz,并将label写入bbox.npy
        np.savez_compressed(os.path.join(output_folder,'%06d_pc.npz'%(data_idx)),
            pc=pc_upright_depth_subsampled)
        np.save(os.path.join(output_folder, '%06d_bbox.npy'%(data_idx)), obbs)
       
        if save_votes:
            N = pc_upright_depth_subsampled.shape[0]
            point_votes = np.zeros((N,10)) # 3 votes and 1 vote mask 
            point_vote_idx = np.zeros((N)).astype(np.int32) # in the range of [0,2]
            indices = np.arange(N)
            # 对每个obj计算相对应的votes
            for obj in objects:
                if obj.classname not in type_whitelist: continue
                try:
                    # Find all points in this object's OBB
                    box3d_pts_3d = sunrgbd_utils.my_compute_box_3d(obj.centroid,
                        np.array([obj.l,obj.w,obj.h]), obj.heading_angle)
                    pc_in_box3d,inds = sunrgbd_utils.extract_pc_in_box3d(\
                        pc_upright_depth_subsampled, box3d_pts_3d)
                    # Assign first dimension to indicate it is in an object box
                    point_votes[inds,0] = 1
                    # Add the votes (all 0 if the point is not in any object's OBB)
                    votes = np.expand_dims(obj.centroid,0) - pc_in_box3d[:,0:3]
                    sparse_inds = indices[inds] # turn dense True,False inds to sparse number-wise inds
                    for i in range(len(sparse_inds)):
                        j = sparse_inds[i]
                        point_votes[j, int(point_vote_idx[j]*3+1):int((point_vote_idx[j]+1)*3+1)] = votes[i,:]
                        # Populate votes with the fisrt vote
                        if point_vote_idx[j] == 0:
                            point_votes[j,4:7] = votes[i,:]
                            point_votes[j,7:10] = votes[i,:]
                    point_vote_idx[inds] = np.minimum(2, point_vote_idx[inds]+1)
                except:
                    print('ERROR ----',  data_idx, obj.classname)
            np.savez_compressed(os.path.join(output_folder, '%06d_votes.npz'%(data_idx)),
                point_votes = point_votes)

Sunrgbd_detection_dataset.py

# Sunrgbd_detection_dataset.py
class SunrgbdDetectionVotesDataset(Dataset):
    def __init__(self, split_set='train', num_points=20000,
        use_color=False, use_height=False, use_v1=False,
        augment=False, scan_idx_list=None)
       
    def __len__(self):
        return len(self.scan_names)

    def __getitem__(self, idx):
   
		#先从文件中加载输入
		point_cloud = np.load(os.path.join(self.data_path, scan_name)+'_pc.npz')['pc'] # Nx6,(x, y, z, r, g, b)
        bboxes = np.load(os.path.join(self.data_path, scan_name
  • 32
    点赞
  • 140
    收藏
    觉得还不错? 一键收藏
  • 65
    评论
VoteNet是一个用于三维目标检测的深度学习模型。它通过将指定目标类别的特征向量转化为概率分布表示,来实现对目标的检测和分类。VoteNet代码详解如下。 VoteNet代码首先定义了一个VoteNet Class,其中包含了模型的网络结构。该网络结构由点云特征提取器、语义分割器、VoteNet层、检测层以及回归层组成。 点云特征提取器用于从点云数据中提取特征,常用的方法有VFE、PointNet等。语义分割器则用于对点云进行语义分割,将不同类别的点云分割开来。 VoteNet层是VoteNet模型的核心部分,它将每个目标体素划分为小的子体素,并为每个子体素生成一个特征描述符。这些特征描述符被编码为概率分布向量,用于表示每个类别的投票。 检测层通过利用VoteNet层生成的特征描述符,来对每个投票进行分类,以确定每个子体素所属的目标类别。 回归层则用于对目标的位姿信息进行回归,包括目标的位置、尺寸和姿态。 在实际应用中,我们可以使用VoteNet代码进行目标检测和分类任务。首先,我们需要准备点云数据和对应的标签,然后利用VoteNet模型对点云数据进行训练。训练过程中,通过计算损失函数来优化模型参数,并实现对目标的检测和分类。 在模型训练完成后,我们可以使用训练好的VoteNet模型对新的点云数据进行预测。通过将点云数据输入模型中,可以得到每个子体素的类别概率分布,从而实现目标的检测和分类。 综上所述,VoteNet代码详解主要涵盖了模型的网络结构以及训练和预测的过程。通过深入理解和实践VoteNet代码,我们可以更好地应用该模型进行三维目标检测任务。
评论 65
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值