model文件夹内包含pointnet_cls.py、pointnet_cls_basic.py、pointnet_seg.py、transform_nets.py四个文件。其中,pointnet_cls.py与pointnet_cls_basic.py差别不大;pointnet_seg.py的函数参数与pointnet_cls.py略有不同;transform_nets.py实现T-Net(input_transform_net与feature_transform_net),分别对输入点云和中间特征做变换。结构图如下:
以下是我自己的理解与代码注释(几个文件的代码相似度很高,这里只贴出pointnet_cls.py):
import tensorflow as tf
import numpy as np
import math
import sys
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
sys.path.append(os.path.join(BASE_DIR, '../utils'))
import tf_util
from transform_nets import input_transform_net, feature_transform_net
def placeholder_inputs(batch_size, num_point):
    """Create the input placeholders for the classification graph.

    Args:
        batch_size: Number of point clouds per batch; used as the first
            dimension of both placeholders.
        num_point: Number of points in each point cloud.

    Returns:
        A ``(pointclouds_pl, labels_pl)`` tuple:
        ``pointclouds_pl`` is a float32 placeholder of shape
        ``(batch_size, num_point, 3)`` (the BxNx3 input that ``get_model``
        documents), and ``labels_pl`` is an int32 placeholder of shape
        ``(batch_size,)``.
    """
    pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
    # Trailing comma matters: ``(batch_size)`` is just an int, not a 1-tuple.
    labels_pl = tf.placeholder(tf.int32, shape=(batch_size,))
    return pointclouds_pl, labels_pl
# placeholder_inputs creates two placeholders: pointclouds_pl (float32, shape (batch_size, num_point, 3)) and labels_pl (int32, shape (batch_size,)).
def get_model(point_cloud, is_training, bn_decay=None):
""" Classification PointNet, input is BxNx3, output Bx40 """
batch_size = point_cloud.get_shape()[0].value
num_point = point_cloud.get_shape()[1].value
end_points = {
}
with tf.variable_scope