train.py
我将代码全部简化,将关键步骤全部列出
hdf5_data_dir = 数据集路径 #读取数据集的路径
创建os.mkdir(train_result) #创建train_result文件夹
color_map_file = part_color_mapping.json #读取颜色json文件路径,一共50类
color_map = json.load() #读取.json文件内容
读取overallid_to_catid_partid.json 列表形式 #读取物体零件编号
training_file_list = train_hdf5 路径
testing_file_list = test_hdf5 路径
model_storage_path = 'trained_models' #在train_result下创建了一个trained_models文件夹,
#用于存放训练好的模型
创建logs日志文件夹
创建summaries文件夹,可视化
def train():
pointclouds_ph = (32,2048,3)
input_label_ph = (32,16)
label_ph = (32)
seg_ph = (32,2048)
"""以上是导入输入数据的占位符"""
batch = 初始化变量0
learning_rate = 指数衰减学习率
bn_decay = 批标准化衰减率
labels_pred , seg_pred, end_points = model.get_model( (32,2048,3),(32,16),...)
#模型训练
loss = get_loss()
#计算损失
train_variables = tf.trainable_variables()
#可训练参数
trainer = tf.train.AdamOptimizer(learning_rate)
#优化器优化
train_op = trainer.minimize(loss, var_list=train_variables, global_step=batch)
#梯度优化,更新var_list最大程度减少损失
saver = tf.train.Saver() #保存和加载模型
init = tf.global_variables_initializer() #全局变量初始化
sess.run(init) #图结构创建好了,开始会话
for epoch in range(training_epoches): #训练次数
eval_one_epoch(epoch)
train_file_idx = np.arange(0,6)
打乱顺序
train_one_epoch(train_file_idx, epoch)
if (epoch + 1) % 10 == 0:
cp_filename = saver.save(sess , 保存路径) #保存训练模型
def eval_one_epoch(epoch_num):
total_label_acc_per_cat = np.zeros(16) #每一类物体分类标签的正确数
total_seg_acc_per_cat = np.zeros(16) #每一类分割正确数
total_seen_per_cat = np.zeros(16) #每类个数
for i in range(num_test_file):
cur_data = (2048,2048,3) #测试集的点云数据
cur_labels =(2048,16) #点云数据物体对应的16类
cur_seg = (2048,2048) #每个点对应的50类其中之一
cur_labels_one_hot = convert_label_to_one_hot(cur_labels)
"""将label都换为one_hot形式"""
for j in range(num_batch): #按批次运行
beginidx-----endidx #开始到结束的索引
loss = sess.run()
per_instance_part_acc = np.mean(pred_seg_res == cur_seg[begidx:endidx,...],axis=1 )
"""求每个物体的零件正确率"""
average_part_acc = np.mean(per_instance_part_acc)
"""求这32个物体的平均零件正确率"""
per_instance_label_pred = np.argmax(label_pred_val, axis=1)
""" 求出这32个物体对类别预测的标签 """
total_label_acc += np.mean(np.float32(per_instance_label_pred == cur_labels[begidx: endidx, ...]))
""" 算出预测标签的正确率并求平均进行累加"""
total_seg_acc += average_part_acc
"""将平均零件分割正确率累加"""
for shape_idx in range(begidx, endidx):
total_seen_per_cat[cur_labels[shape_idx]] += 1
"""test过的每一类的个数"""
total_label_acc_per_cat[cur_labels[shape_idx]]+=np.int32(per_instance_label_pred[shape_idx-begidx] == cur_labels[shape_idx])
"""每一类标签判断正确的个数:预测标签与正确标签对比,如果正确就在相应位置+1"""
total_seg_acc_per_cat[cur_labels[shape_idx]] += per_instance_part_acc[shape_idx - begidx]
"""将每个物体分割的正确率累加"""
total_loss = total_loss * 1.0 / total_seen
total_label_loss = total_label_loss * 1.0 / total_seen
total_seg_loss = total_seg_loss * 1.0 / total_seen
total_label_acc = total_label_acc * 1.0 / total_seen
total_seg_acc = total_seg_acc * 1.0 / total_seen
train_one_epoch 比 eval_one_epoch 多一个优化器(train_op)优化过程
test.py
自行定义命令参数获取 model_path ,保存的训练模型
pretrained_model_path = FLAGS.model_path #获取保存好的模型
hdf5_data_dir = './hdf5_data' # 获取h5数据集
ply_data_dir = './PartAnnotation' # 导入测试数据集
test_file_list = os.path.join(BASE_DIR, 'testing_ply_file_list.txt')
""" testing_ply_file_list.txt为从PartAnnotation数据集中采样出的2874个数据,分别包括
点云数据 / 分割数据 / 实例类别编号"""
oid2cpid = 'overallid_to_catid_partid.json'
"""oid2cpid读取物体零件编号
[["02691156", 1], ["02691156", 2],....]"""
object2setofoid = {} #oid对象集
for idx in range(len(oid2cpid)):
objid, pid = oid2cpid[idx] #objid对象标识符 pid编号
if not objid in object2setofoid.keys():
object2setofoid[objid] = []
object2setofoid[objid].append(idx)
"""创建一个字典,将每个物体编号按顺序0~49索引排序
{'02691156':[0,1,2,3],'02773838':[4,5],.....}"""
all_obj_cat_file = 'all_object_categories.txt'
获取16类物体和编号的文件,并分别划分到两个列表中
objcats = split()[0]
"""['02691156','02773838',......]"""
objnames = split()[1]
"""['Airplane','Bag',......]"""
color_map = json.load('part_color_mapping.json') 获取颜色
cpid2oid = 'catid_partid_to_overallid.json'
"""cpid2oid为对物体零件进行分类1~50类对应
{"03642806_2": 29, "03642806_1": 28,...."""
------------------------------------数据集的前期处理全部完成-----------------------------------
def predict():
pointclouds_ph = (1,3000,3)
input_label_ph = (1,16)
pred , seg_pred , end_points = get_model(pointclouds_ph, input_label_ph,...)
"""模型占位符"""
saver = tf.train.Saver()
"""添加操作用来保存和重现所有变量"""
with tf.Session(config=config) as sess:
saver.restore(sess, pretrained_model_path)
"""导入训练好的模型"""
batch_data = np.zeros((1, 3000, 3))
total_per_cat_acc = np.zeros(16)
"""每一类正确的个数"""
total_per_cat_iou = np.zeros(16)
""" 每一类的IOU"""
total_per_cat_seen = np.zeros(16)
""" 每一类测试的总个数"""
获取测试用的数据集test_file_list,并进行预处理,将其划分为3类列表
pts_files = split()[0]
"""获取的点云文件路径"""
seg_files = split()[1]
"""获取seg文件路径"""
labels = split()[2]
""" 获取物体类别编号"""
"""开始逐个对测试数据集中的数据进行操作,测试数据有2874个"""
for shape_idx in range(len_pts_files):
cur_gt_label = on2oid[labels[shape_idx]]
""" on2oid为物体编号对应索引,总共有16个,获取当前数据集的编号对应索引"""
将其转换为独热编码
pts_file_to_load = os.path.join(ply_data_dir, pts_files[shape_idx])
seg_file_to_load = os.path.join(ply_data_dir, seg_files[shape_idx])
"""根据shape_idx将pts文件和seg文件读取出来"""
pts, seg = load_pts_seg_files(pts_file_to_load, seg_file_to_load, objcats[cur_gt_label])
"""将各物体编号都统一到1~50类当中,这个操作非常关键!!!!! """
def load_pts_seg_files(pts_file, seg_file, catid):
    """Load one point cloud and its per-point part labels.

    Reads a .pts file (one "x y z" triple per line) and the matching .seg
    file (one integer part index per line, numbered 1, 2, 3, ... within the
    single object), then remaps each per-object part index into the global
    part-label space via the module-level ``cpid2oid`` dict using the key
    ``"<catid>_<part_index>"``.

    Args:
        pts_file: path to the .pts point-cloud file.
        seg_file: path to the .seg per-point part-index file.
        catid: category id string (e.g. '02691156') used to build cpid2oid keys.

    Returns:
        (pts, seg): pts is a float32 array of shape (N, 3); seg is an int
        array of length N with labels in the global part-label space.
    """
    with open(pts_file, 'r') as f:
        pts_str = [item.rstrip() for item in f.readlines()]
        pts = np.array([np.float32(s.split()) for s in pts_str], dtype=np.float32)
        a = len(pts)  # NOTE(review): unused here; presumably the caller's ori_point_num — confirm
    with open(seg_file, 'r') as f:
        # per-object part indices, e.g. [2 2 2 1 1 1 ...]
        part_ids = np.array([int(item.rstrip()) for item in f.readlines()], dtype=np.uint8)
        # map each per-category part index into the global label space via cpid2oid
        seg = np.array([cpid2oid[catid + '_' + str(x)] for x in part_ids])
    # FIX: the original notes omitted the return; the call site expects (pts, seg)
    return pts, seg
label_pred_val , seg_pred_res = sess.run()
""" 预测出的label 和 seg"""
label_pred_val = np.argmax(label_pred_val[0, :])
"""将预测出的label得出"""
seg_pred_res = seg_pred_res[0, ...] #进行降维处理
c = seg_pred_res.shape #(3000,50)
iou_oids = object2setofoid[objcats[cur_gt_label]]
""" 将该物体的零件索引提取出来
objacts:['02691156','02773838',......]
object2setofoid:{'02691156':[0,1,2,3],'02773838':[4,5],.....}
[12,13,14,15]"""
non_cat_labels = list(set(np.arange(NUM_PART_CATS)).difference(set(iou_oids)))
"""创建一个0~49的数组,剔除12,13,14,15"""
mini = np.min(seg_pred_res) #获取预测中的最小值
seg_pred_res[:, non_cat_labels] = mini - 1000 #将除12,13,14,15的其他标签都减小
seg_pred_val = np.argmax(seg_pred_res, axis=1)[:ori_point_num]
"""比较12,13,14,15这个位置的数,取最大判断为该类"""
seg_acc = np.mean(seg_pred_val == seg)
"""预测的类与正确实际的类做比较,得出seg的正确率"""
total_acc += seg_acc
"""将分割的正确率进行累加"""
total_seen += 1
""" 测试总的个数"""
total_per_cat_seen[cur_gt_label] += 1
total_per_cat_acc[cur_gt_label] += seg_acc
mask = np.int32(seg_pred_val == seg)
"""预测类与正确的比较,相等为1,不等为0"""
计算IOU = n_intersect/(n_pred + n_seg - n_intersect)
n_pred = 预测的12标签的个数
n_seg = 实际的12标签的个数
n_intersect = 判断正确的12标签的个数
"""对预测结果,保存在obj文件"""
if output_verbose:
output_color_point_cloud(pts, seg, os.path.join(output_dir, str(shape_idx)+'_gt.obj'))
output_color_point_cloud(pts, seg_pred_val, os.path.join(output_dir, str(shape_idx)+'_pred.obj'))
output_color_point_cloud_red_blue(pts, np.int32(seg == seg_pred_val), os.path.join(output_dir, str(shape_idx)+'_diff.obj'))