Pointnet(part_seg)train.py,test.py代码随记

train.py

我将代码全部简化,将关键步骤全部列出

hdf5_data_dir = 数据集路径                  #读取数据集的路径
 
创建os.mkdir(train_result)                 #创建train_result文件夹
 
color_map_file  =  part_color_mapping.json #读取颜色json文件路径,一共50类
 
color_map = json.load()                    #读取.json文件内容
 
读取overallid_to_catid_partid.json 列表形式  #读取物体零件编号

training_file_list = train_hdf5 路径 
testing_file_list  = test_hdf5 路径

model_storage_path = 'trained_models'      #在train_result下创建了一个trained_models文件夹,
                                           #用于存放训练好的模型
创建logs日志文件夹
创建summaries文件夹,可视化

def train():
    pointclouds_ph = (32,2048,3)
    input_label_ph = (32,16)
    label_ph = (32)
    seg_ph = (32,2048)        
    """以上是导入输入数据的占位符"""
    batch = 初始化变量0
    learning_rate = 指数衰减学习率
    bn_decay = 批标准化衰减率
    
    labels_pred , seg_pred, end_points = model.get_model( (32,2048,3),(32,16),...)
    #模型训练
    loss = get_loss()
    #计算损失
    train_variables = tf.trainable_variables() 
    #可训练参数
    trainer = tf.train.AdamOptimizer(learning_rate) 
    #优化器优化
    train_op = trainer.minimize(loss, var_list=train_variables, global_step=batch)
    #梯度优化,更新var_list最大程度减少损失

    saver = tf.train.Saver() #保存和加载模型
    
    init = tf.global_variables_initializer() #全局变量初始化
    sess.run(init)                  #图结构创建好了,开始会话       

    for epoch in range(training_epoches):  #训练次数
      eval_one_epoch(epoch)

      train_file_idx = np.arange(0,6)
      打乱顺序
      train_one_epoch(train_one_idx , epoch)
      if(epoch+1) %10 == 0:
           cp_filename = saver.save(sess , 保存路径)  #保存训练模型

    def eval_one_epoch(epoch_num):
       total_label_acc_per_cat = np.zeros[16]  #每一类物体分类标签的正确数
       total_seg_acc_per_cat = np.zeros[16]   #每一类分割正确数
       total_seen_per_cat = np.zeros[16]      #每类个数

       for i in range(num_test_file):
           cur_data = (2048,2048,3)   #测试集的点云数据
           cur_labels =2048,16#点云数据物体对应的16类
           cur_seg  =20482048#每个点对应的50类其中之一
           
           cur_labels_one_hot = convert_label_to_one_hot(cur_labels)
           """将label都换为one_hot形式"""
    
           for j in range(num_batch):   #按批次运行
              beginidx-----endidx      #开始到结束的索引
              loss = sess.run()
              
              per_instance_part_acc = np.mean(pred_seg_res == cur_seg[begidx:endidx,...],axis=1 )
              """求每个物体的零件正确率"""
              average_part_acc = np.mean(per_instance_part_acc)
              """求这32个物体的平均零件正确率"""
              per_instance_label_pred = np.argmax(label_pred_val, axis=1)
              """ 求出这32个物体对类别预测的标签 """
              total_label_acc += np.mean(np.float32(per_instance_label_pred == cur_labels[begidx: endidx, ...]))
              """ 算出预测标签的正确率并求平均进行累加"""
              total_seg_acc += average_part_acc 
              """将平均零件分割正确率累加"""
              for shape_idx in range(begidx, endidx):
                  total_seen_per_cat[cur_labels[shape_idx]] += 1
                  """test过的每一类的个数"""
                  total_label_acc_per_cat[cur_labels[shape_idx]]+=np.int32(per_instance_label_pred[shape_idx-begidx] == cur_labels[shape_idx])
                  """每一类标签判断正确的个数:预测标签与正确标签对比,如果正确就在相应位置+1"""
                  total_seg_acc_per_cat[cur_labels[shape_idx]] += per_instance_part_acc[shape_idx - begidx]
                  """将每个物体分割的正确率累加"""
             
                  total_loss = total_loss * 1.0 / total_seen
                  total_label_loss = total_label_loss * 1.0 / total_seen
                  total_seg_loss = total_seg_loss * 1.0 / total_seen
                  total_label_acc = total_label_acc * 1.0 / total_seen
                  total_seg_acc = total_seg_acc * 1.0 / total_seen
      

train_one_proch比eval_one_peoch 多一个优化器过程

test.py

自行定义命令参数获取 model_path ,保存的训练模型

pretrained_model_path = FLAGS.model_path  #获取保存好的模型
hdf5_data_dir = './hdf5_data'  # 获取h5数据集
ply_data_dir = './PartAnnotation' # 导入测试数据集

test_file_list = os.path.join(BASE_DIR, 'testing_ply_file_list.txt')
""" testing_ply_file_list.txt为从PartAnnotation数据集中采样出的2874个数据,分别包括
    点云数据 / 分割数据 / 实例类别编号"""

oid2cpid =  'overallid_to_catid_partid.json'
"""oid2cpid读取物体零件编号
   [["02691156", 1], ["02691156", 2],....]"""  
   
object2setofoid = {}   #oid对象集
for idx in range(len(oid2cpid)):
    objid, pid = oid2cpid[idx]   #objid对象标识符  pid编号
    if not objid in object2setofoid.keys():
        object2setofoid[objid] = []
    object2setofoid[objid].append(idx)
"""创建一个字典,将每个物体编号按顺序0~49索引排序
   {'02691156':[0,1,2,3],'02773838':[4,5],.....}"""

all_obj_cat_file =  'all_object_categories.txt' 
 获取16类物体和编号的文件,并分别划分到两个列表中
objcats =  split()[0]
"""['02691156','02773838',......]"""  
objnames =  split()[1]
"""['Airplane','Bag',......]"""

color_map = json.load('part_color_mapping.json')   获取颜色

cpid2oid = 'catid_partid_to_overallid.json'  
"""cpid2oid为对物体零件进行分类1~50类对应
   {"03642806_2": 29, "03642806_1": 28,...."""

------------------------------------数据集的前期处理全部完成-----------------------------------
def predict():
   pointclouds_ph = (1,3000,3)
   input_label_ph = (1,16)
   
   pred , seg_pred , end_points = get_model(pointclouds_ph, input_label_ph,...)
   """模型占位符"""
   saver = tf.train.Saver()
   """添加操作用来保存和重现所有变量"""
   
   with tf.Seesion(config=config) as sess:
        saver.restore(sess, pretrained_model_path)
        """导入训练好的模型"""
        batch_data = np.zeros[1,3000,3]
        
        total_per_cat_acc = np.zeros(16)
        """每一类正确的个数"""
        total_per_cat_iou = np.zeros(16)
        """ 每一类的IOU"""
        total_per_cat_seen = np.zeros(16)
        """ 每一类测试的总个数"""

       获取测试用的数据集test_file_list,并进行预处理,将其划分为3类列表
       pts_files = split()[0] 
       """获取的点云文件路径"""
       seg_files = split()[1]
       """获取seg文件路径"""
       labels = split()[2]
       """ 获取物体类别编号"""
       
       """开始逐个对测试数据集中的数据进行操作,测试数据有2874个"""
       for shape_idx in range(len_pts_files):
             cur_gt_label = on2oid[labels[shape_idx]]
             """ on2oid为物体编号对应索引,总共有16个,获取当前数据集的编号对应索引"""
             将其转换为独热编码
            
            pts_file_to_load = os.path.join(ply_data_dir, pts_files[shape_idx])
            seg_file_to_load = os.path.join(ply_data_dir, seg_files[shape_idx])
            """根据shape_idx将pts文件和seg文件读取出来"""
            
            pts, seg = load_pts_seg_files(pts_file_to_load, seg_file_to_load, objcats[cur_gt_label])
            """将各物体编号都统一到1~50类当中,这个操作非常关键!!!!! """
             
def load_pts_seg_files(pts_file, seg_file, catid):
     with open(pts_file, 'r') as f:
          pts_str = [item.rstrip() for item in f.readlines()]
          pts = np.array([np.float32(s.split()) for s in pts_str], dtype=np.float32)
          a = len(pts)   
     with open(seg_file, 'r') as f:
          part_ids = np.array([int(item.rstrip()) for item in f.readlines()], dtype=np.uint8)
         """在单独一个物体中以1,2,3将不同零件进行分类,得出的零件索引[2 2 2 1 1 1 1 1 ....]"""
          seg = np.array([cpid2oid[catid+'_'+str(x)] for x in part_ids])
         """cpid2oid为每个物体零件对应的0~50类编号,将单个物体零件的分类通过cpid2oid转换为总的50类别"""
         
         label_pred_val , seg_pred_res = sess.run()
         """ 预测出的label 和 seg"""
         label_pred_val = np.argmax(label_pred_val[0, :]) 
         """将预测出的label得出"""
         seg_pred_res = seg_pred_res[0,....]   #进行降维处理
         c = seg_pred_res.shaoe    #(3000,50)
         
         iou_oids = object2setofoid[objcats[cur_gt_label]]
         """ 将该物体的零件索引提取出来
             objacts:['02691156','02773838',......]
             object2setofoid:{'02691156':[0,1,2,3],'02773838':[4,5],.....}
             [12,13,14,15]"""
         non_cat_labels = list(set(np.arange(NUM_PART_CATS)).difference(set(iou_oids)))   
         """创建一个0~49的数组,剔除12,13,14,15"""
         mini = np.min(seg_pred_res)  #获取预测中的最小值
         seg_pred_res[:, non_cat_labels] = mini - 1000  #将除12,13,14,15的其他标签都减小

         seg_pred_val = np.argmax(seg_pred_res, axis=1)[:ori_point_num]
         """比较12,13,14,15这个位置的数,取最大判断为该类"""
         
         seg_acc = np.mean(seg_pred_val == seg)
         """预测的类与正确实际的类做比较,得出seg的正确率"""
         
         total_acc += seg_acc
         """将分割的正确率进行累加"""
         total_seen += 1
         """ 测试总的个数"""
         total_per_cat_seen[cur_gt_label] += 1
         total_per_cat_acc[cur_gt_label] += seg_acc

         mask = np.int32(seg_pred_val == seg)
         """预测类与正确的比较,相等为1,不等为0"""         
         
         计算IOU = n_intersect/(n_pred + n_seg - n_intersect)
         n_pred = 预测的12标签的个数
         n_seg  = 实际的12标签的个数
         n_intersect = 判断正确的12标签的个数
         
        """对预测结果,保存在obj文件"""
        if output_verbose:
                output_color_point_cloud(pts, seg, os.path.join(output_dir, str(shape_idx)+'_gt.obj'))
                output_color_point_cloud(pts, seg_pred_val, os.path.join(output_dir, str(shape_idx)+'_pred.obj'))
                output_color_point_cloud_red_blue(pts, np.int32(seg == seg_pred_val), os.path.join(output_dir, str(shape_idx)+'_diff.obj'))






  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值