Pointnet(part_seg)部件分割网络代码解析:test.py(三)

论文地址:PointNet网络
代码地址:https://github.com/charlesq34/pointnet
windows10环境配置可参考这篇文章:windows下运行pointnet(全)

网络图示:

在这里插入图片描述

一、训练模型

1)下载测试集
数据下载:PartAnnotation(注意该链接中对应的数据集有两部分,因此下载时两部分务必都要下载)
存放位置:…/part_seg/PartAnnotation。
HDF5文件解压放在pointnet-master/part_seg/hdf5_data。
2)训练模型
可参考Pointnet(part_seg)部件分割网络代码解析:train.py(二)文章,里面有代码详解,以及训练模型的获取。
3)测试集数据解析
在这里插入图片描述

二、test.py代码详解

1、选择第80个训练模型进行测试在这里插入图片描述

parser = argparse.ArgumentParser()
parser.add_argument('--model_path', default='train_results/trained_models/epoch_80.ckpt', help='Model checkpoint path')
FLAGS = parser.parse_args()

2、setting

parser = argparse.ArgumentParser()
# 自行定义命令参数获取 model_path ,保存的训练模型
parser.add_argument('--model_path', default='train_results/trained_models/epoch_80.ckpt', help='Model checkpoint path')
FLAGS = parser.parse_args()


# DEFAULT SETTINGS
pretrained_model_path = FLAGS.model_path # 获取保存好的模型
hdf5_data_dir = os.path.join(BASE_DIR, './hdf5_data') # 获取h5数据集
ply_data_dir = os.path.join(BASE_DIR, './PartAnnotation') # 导入测试数据集
gpu_to_use = 0
output_dir = os.path.join(BASE_DIR, './test_results')
output_verbose = True   # If true, output all color-coded part segmentation obj files

# MAIN SCRIPT
point_num = 3000            # the max number of points in the all testing data shapes
batch_size = 1

test_file_list = os.path.join(BASE_DIR, 'testing_ply_file_list.txt')
""" testing_ply_file_list.txt为从PartAnnotation数据集中采样出的2874个数据,分别包括
    点云数据 / 分割数据 / 实例类别编号"""

oid2cpid = json.load(open(os.path.join(hdf5_data_dir, 'overallid_to_catid_partid.json'), 'r'))
"""oid2cpid读取物体零件编号
   [["02691156", 1], ["02691156", 2],....]"""

object2setofoid = {} # oid对象集
for idx in range(len(oid2cpid)):
    objid, pid = oid2cpid[idx] # oid对象标识符  pid编号
    if not objid in object2setofoid.keys():
        object2setofoid[objid] = []
    object2setofoid[objid].append(idx) # 创建一个字典,将每个物体编号按顺序0-49索引排序
'''{‘02691156’:[0,1,2,3],'02247898':[4,5],....}'''

''' 获取16类物体和编号的文件,并划分到两个列表中'''
all_obj_cat_file = os.path.join(hdf5_data_dir, 'all_object_categories.txt')
fin = open(all_obj_cat_file, 'r')

'''
    objcats: ['02691156','14546466'....]
    objnames: ['Airplane','Bag'....]
'''
lines = [line.rstrip() for line in fin.readlines()]
objcats = [line.split()[1] for line in lines]
objnames = [line.split()[0] for line in lines]
on2oid = {objcats[i]:i for i in range(len(objcats))}  # on2oid为物体编号对应的索引,共16个
fin.close()

color_map_file = os.path.join(hdf5_data_dir, 'part_color_mapping.json')
color_map = json.load(open(color_map_file, 'r')) # 获取颜色

NUM_OBJ_CATS = 16 # 16个物体对象
NUM_PART_CATS = 50 # 分割成50个类别

'''为物体零件进行分类1-50类对应 {'03642806_2':29,'03642806_1':28,.....}'''
cpid2oid = json.load(open(os.path.join(hdf5_data_dir, 'catid_partid_to_overallid.json'), 'r'))

3、load_pts_seg_files方法

def load_pts_seg_files(pts_file, seg_file, catid):
    with open(pts_file, 'r') as f:
        pts_str = [item.rstrip() for item in f.readlines()]
        pts = np.array([np.float32(s.split()) for s in pts_str], dtype=np.float32)
    with open(seg_file, 'r') as f:
        '''在单独一个物体中以1,2,3讲不同零件进行分类,得出的零件索引[2 2 2 1 1 1 1 1 .....]'''
        part_ids = np.array([int(item.rstrip()) for item in f.readlines()], dtype=np.uint8)
        '''cpid2oid为每个物体零件对应的0-50类编号,将单个物体零件的分类通过cpid2oid转换为总的50类别'''
        seg = np.array([cpid2oid[catid+'_'+str(x)] for x in part_ids])
    return pts, seg

4、predict()方法

def predict():
    is_training = False
    
    with tf.device('/gpu:'+str(gpu_to_use)):
        pointclouds_ph, input_label_ph = placeholder_inputs()
        is_training_ph = tf.placeholder(tf.bool, shape=())

        # simple model
        # pointclouds_ph:(1,3000,3)  seg_pred:(1,16)
        pred, seg_pred, end_points = model.get_model(pointclouds_ph, input_label_ph, \
                cat_num=NUM_OBJ_CATS, part_num=NUM_PART_CATS, is_training=is_training_ph, \
                batch_size=batch_size, num_point=point_num, weight_decay=0.0, bn_decay=None)
        
    # Add ops to save and restore all the variables.//添加操作用来保存和重现所有变量
    saver = tf.train.Saver()

    # Later, launch the model, use the saver to restore variables from disk, and
    # do some work with the model.
    
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)

        flog = open(os.path.join(output_dir, 'log.txt'), 'w')

        # Restore variables from disk.
        printout(flog, 'Loading model %s' % pretrained_model_path)
        saver.restore(sess, pretrained_model_path) # 导入训练好的模型
        printout(flog, 'Model restored.')
        
        # Note: the evaluation for the model with BN has to have some statistics
        # Using some test datas as the statistics
        batch_data = np.zeros([batch_size, point_num, 3]).astype(np.float32) # [1,3000,3]

        total_acc = 0.0
        total_seen = 0
        total_acc_iou = 0.0

        # NUM_OBJ_CATS=16
        total_per_cat_acc = np.zeros((NUM_OBJ_CATS)).astype(np.float32) # 每一类正确的个数
        total_per_cat_iou = np.zeros((NUM_OBJ_CATS)).astype(np.float32) # 每一类的IOU
        total_per_cat_seen = np.zeros((NUM_OBJ_CATS)).astype(np.int32) # 每一类测试的总个数

        ffiles = open(test_file_list, 'r') # 获取测试用的数据集,并进行预处理,将其划分为3类列表
        lines = [line.rstrip() for line in ffiles.readlines()]
        pts_files = [line.split()[0] for line in lines] # 获取点云文件路径(有2874个路径)
        seg_files = [line.split()[1] for line in lines] # 获取seg文件路径
        labels = [line.split()[2] for line in lines] # 获取物体类别编号
        ffiles.close()

        len_pts_files = len(pts_files)
        for shape_idx in range(len_pts_files):
            if shape_idx % 100 == 0:
                printout(flog, '%d/%d ...' % (shape_idx, len_pts_files))

            # on2oid为物体编号对应的索引,共16个,获取当前数据集的编号对应索引,将其转换为独热编码
            cur_gt_label = on2oid[labels[shape_idx]]

            cur_label_one_hot = np.zeros((1, NUM_OBJ_CATS), dtype=np.float32)
            cur_label_one_hot[0, cur_gt_label] = 1

            # 根据shape_idx将pts和seg文件读取处理
            pts_file_to_load = os.path.join(ply_data_dir, pts_files[shape_idx])
            seg_file_to_load = os.path.join(ply_data_dir, seg_files[shape_idx])

            # 将各物体编号都统一到1-50类当中,这个操作非常关键!!!
            pts, seg = load_pts_seg_files(pts_file_to_load, seg_file_to_load, objcats[cur_gt_label])
            ori_point_num = len(seg)

            batch_data[0, ...] = pc_augment_to_point_num(pc_normalize(pts), point_num)

            # 预测出的label和seg
            label_pred_val, seg_pred_res = sess.run([pred, seg_pred], feed_dict={
                        pointclouds_ph: batch_data,
                        input_label_ph: cur_label_one_hot, 
                        is_training_ph: is_training,
                    })

            # 将预测出的label得出
            label_pred_val = np.argmax(label_pred_val[0, :])
            
            seg_pred_res = seg_pred_res[0, ...] # 进行降维处理

            # 将该物体的索引提取出来
            # objacts:['02691156','02773838',...]
            # object2setofoid:{'02691156':[0,1,2,3],'02773838':[4,5],...}
            iou_oids = object2setofoid[objcats[cur_gt_label]]

            # 创建一个0-49的数组,剔除12,13,14,15
            non_cat_labels = list(set(np.arange(NUM_PART_CATS)).difference(set(iou_oids)))

            mini = np.min(seg_pred_res) # 获取预测中的最小值
            seg_pred_res[:, non_cat_labels] = mini - 1000 # 将除12,,1,14,15的其他标签都减小

            # 比较12,13,14,15这个位置得数,取最大判断为该类
            seg_pred_val = np.argmax(seg_pred_res, axis=1)[:ori_point_num]

            # 预测的类与正确实际的类作比较,得出seg的正确率
            seg_acc = np.mean(seg_pred_val == seg)

            # 将分割的正确率进行累加
            total_acc += seg_acc

            # 测试总的个数
            total_seen += 1

            total_per_cat_seen[cur_gt_label] += 1
            total_per_cat_acc[cur_gt_label] += seg_acc

            # 预测类与正确的比较,相等为1,不等为0
            mask = np.int32(seg_pred_val == seg)

            # 计算IOU = n_intersect/(n_pred + n_seg - n_intersect)
            # n_pred:预测的标签的个数
            # n_seg:实际的标签个数
            # n_intersect判断正确的标签个数
            total_iou = 0.0
            iou_log = ''
            for oid in iou_oids:
                n_pred = np.sum(seg_pred_val == oid)
                n_gt = np.sum(seg == oid)
                n_intersect = np.sum(np.int32(seg == oid) * mask)
                n_union = n_pred + n_gt - n_intersect
                iou_log += '_' + str(n_pred)+'_'+str(n_gt)+'_'+str(n_intersect)+'_'+str(n_union)+'_'
                if n_union == 0:
                    total_iou += 1
                    iou_log += '_1\n'
                else:
                    total_iou += n_intersect * 1.0 / n_union
                    iou_log += '_'+str(n_intersect * 1.0 / n_union)+'\n'

            avg_iou = total_iou / len(iou_oids)
            total_acc_iou += avg_iou
            total_per_cat_iou[cur_gt_label] += avg_iou

            # 对预测结果,保存在obj文件中
            if output_verbose:
                output_color_point_cloud(pts, seg, os.path.join(output_dir, str(shape_idx)+'_gt.obj'))
                output_color_point_cloud(pts, seg_pred_val, os.path.join(output_dir, str(shape_idx)+'_pred.obj'))
                output_color_point_cloud_red_blue(pts, np.int32(seg == seg_pred_val), 
                        os.path.join(output_dir, str(shape_idx)+'_diff.obj'))

                with open(os.path.join(output_dir, str(shape_idx)+'.log'), 'w') as fout:
                    fout.write('Total Point: %d\n\n' % ori_point_num)
                    fout.write('Ground Truth: %s\n' % objnames[cur_gt_label])
                    fout.write('Predict: %s\n\n' % objnames[label_pred_val])
                    fout.write('Accuracy: %f\n' % seg_acc)
                    fout.write('IoU: %f\n\n' % avg_iou)
                    fout.write('IoU details: %s\n' % iou_log)

        printout(flog, 'Accuracy: %f' % (total_acc / total_seen))
        printout(flog, 'IoU: %f' % (total_acc_iou / total_seen))

        for cat_idx in range(NUM_OBJ_CATS):
            printout(flog, '\t ' + objcats[cat_idx] + ' Total Number: ' + str(total_per_cat_seen[cat_idx]))
            if total_per_cat_seen[cat_idx] > 0:
                printout(flog, '\t ' + objcats[cat_idx] + ' Accuracy: ' + \
                        str(total_per_cat_acc[cat_idx] / total_per_cat_seen[cat_idx]))
                printout(flog, '\t ' + objcats[cat_idx] + ' IoU: '+ \
                        str(total_per_cat_iou[cat_idx] / total_per_cat_seen[cat_idx]))

       

三、结果查看

在路径…/part_seg/test_results文件中可以查看测试结果
在这里插入图片描述
1、生成的是obj文件:
对同一个椅子obj文件,可以通过CloudCompare和meshlab查看,效果如下两张图(meshlab查看时可以查看到颜色信息,CloudCompare只能查看到灰色点云,推荐使用meshlab)
在这里插入图片描述
在这里插入图片描述
2、路径…/part_seg/test_results/log.txt文件中存放的是测试时的日志数据

在这里插入图片描述
在这里插入图片描述

  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
PointNet2是一种针对点云分类和分割任务的深度学习框架。PointNet2_Part_Seg_SSG是基于PointNet2框架的一个应用,用于点云部分分割任务。 PointNet2使用了一种层级的神经网络结构,能够有效地处理无序的点云数据。它将点云分为多个局部区域,对每个区域进行特征提取,最后整合局部特征得到全局特征表达。这种设计能够提取点云的局部和全局特征,从而实现对点云数据的分类和分割PointNet2_Part_Seg_SSG是PointNet2框架的一种改进,主要针对点云的部分分割任务。它使用了SSG(Single-Scale Grouping)模块,通过分组聚合点的特征,从而对点云进行细分。SSG模块首先选择每个局部区域中的中心点,并将其他点分配给最近的中心点。然后,SSG模块对每个中心点的邻域进行特征提取和聚合,得到该局部区域的特征表示。最后,通过进一步的卷积和池化操作,得到点云的全局特征表示。 在训练过程中,PointNet2_Part_Seg_SSG使用交叉熵损失函数来度量预测的分割结果与真实标签之间的差异。通过反向传播算法,可以优化网络的参数,使得网络能够更好地学习点云的特征表示和分割任务。 总的来说,PointNet2_Part_Seg_SSG是基于PointNet2框架的一个改进版本,专门用于点云的部分分割任务。它通过采用SSG模块,能够对点云进行更精细的细分和特征提取,从而提高了点云分割任务的准确性和效果。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值