Point Cloud Deep Learning Series 5: PointNet++ Paper and Code Analysis


Hello, everyone.

PointNet++ is an upgraded version of PointNet that adds the ability to capture local structure. This shows up as quite a few changes in the code, so let us take classification as the example and walk through the architecture and the code.

Network Structure

 

First, the network structure itself. For a refresher on the original PointNet architecture, see the earlier article in this series.

The improved version drops the T-Net. The network has more layers, but they are organized more systematically.

 

 
def get_model(point_cloud, is_training, bn_decay=None):
    """ Classification PointNet, input is BxNx3, output Bx40 """
    batch_size = point_cloud.get_shape()[0].value
    num_point = point_cloud.get_shape()[1].value
    end_points = {}
    l0_xyz = point_cloud
    l0_points = None
    end_points['l0_xyz'] = l0_xyz

    # Set abstraction layers
    # Note: When using NCHW for layer 2, we see increased GPU memory usage (in TF1.4).
    # So we only use NCHW for layer 1 until this issue can be resolved.
    l1_xyz, l1_points, l1_indices = pointnet_sa_module(l0_xyz, l0_points, npoint=512, radius=0.2, nsample=32, mlp=[64,64,128], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer1', use_nchw=True)
    l2_xyz, l2_points, l2_indices = pointnet_sa_module(l1_xyz, l1_points, npoint=128, radius=0.4, nsample=64, mlp=[128,128,256], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer2')
    l3_xyz, l3_points, l3_indices = pointnet_sa_module(l2_xyz, l2_points, npoint=None, radius=None, nsample=None, mlp=[256,512,1024], mlp2=None, group_all=True, is_training=is_training, bn_decay=bn_decay, scope='layer3')

    # Fully connected layers
    net = tf.reshape(l3_points, [batch_size, -1])
    net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, scope='fc1', bn_decay=bn_decay)
    net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp1')
    net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, scope='fc2', bn_decay=bn_decay)
    net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp2')
    net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3')

    return net, end_points
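As a quick sanity check, the model can be built on placeholder inputs. This is only a minimal usage sketch, not code from the repository: it assumes the repository's models/ and utils/ directories are on the Python path (so that tf, tf_util, and pointnet_sa_module resolve), and the batch size of 16 and 1024 points per cloud are arbitrary choices of mine.

with tf.Graph().as_default():
    # BxNx3 input placeholder: 16 clouds of 1024 XYZ points (sizes are arbitrary here)
    pointclouds_pl = tf.placeholder(tf.float32, shape=(16, 1024, 3))
    is_training_pl = tf.placeholder(tf.bool, shape=())
    pred, end_points = get_model(pointclouds_pl, is_training_pl)
    print(pred.shape)  # (16, 40) -- one score per ModelNet40 class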

 

As with PointNet, the code above can be read as two parts: feature extraction and the classification head.

 

The feature extraction part corresponds to the set abstraction layers in the code. Note that there is no T-Net; the point cloud is processed directly. Feature extraction consists of three pointnet_sa_module blocks, each containing three MLP layers and one pooling layer, so nine MLP layers in total are used for feature extraction.
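To make the data flow concrete, here is a shape trace through the three set abstraction modules and the fully connected head. The batch size B is symbolic, and the 1024-point input is an assumption on my part (it matches the common ModelNet40 setup, but the actual value depends on how the data is fed):

# Feature shapes produced by each stage of get_model (B = batch size).
shapes = [
    ("l0_xyz",    ("B", 1024, 3)),   # raw input coordinates
    ("l1_points", ("B", 512, 128)),  # layer1: npoint=512, mlp=[64,64,128]
    ("l2_points", ("B", 128, 256)),  # layer2: npoint=128, mlp=[128,128,256]
    ("l3_points", ("B", 1, 1024)),   # layer3: group_all=True, mlp=[256,512,1024]
    ("fc1",       ("B", 512)),       # after flattening l3_points to Bx1024
    ("fc2",       ("B", 256)),
    ("fc3",       ("B", 40)),        # one logit per ModelNet40 class
]
for name, shape in shapes:
    print("%-10s %s" % (name, shape))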

The code for the pointnet_sa_module block is as follows:

 
def pointnet_sa_module(xyz, points, npoint, radius, nsample, mlp, mlp2, group_all, is_training, bn_decay, scope, bn=True, pooling='max', knn=False, use_xyz=True, use_nchw=False):
    ''' PointNet Set Abstraction (SA) Module
        Input:
            xyz: (batch_size, ndataset, 3) TF tensor
            points: (batch_size, ndataset, channel) TF tensor
            npoint: int32 -- #points sampled in farthest point sampling (i.e., the number of centroids)
            radius: float32 -- search radius in local region
            nsample: int32 -- how many points in each local region
            mlp: list of int32 -- output size for MLP on each point
            mlp2: list of int32 -- output size for MLP on each region
            group_all: bool -- group all points into one PC if set true, OVERRIDE
                npoint, radius and nsample settings
            use_xyz: bool, if True concat XYZ with local point features, otherwise just use point features
            use_nchw: bool, if True, use NCHW data format for conv2d, which is usually faster than NHWC format
        Return:
            new_xyz: (batch_size, npoint, 3) TF tensor
            new_points: (batch_size, npoint, mlp[-1] or mlp2[-1]) TF tensor
            idx: (batch_size, npoint, nsample) int32 -- indices for local regions
    '''
    data_format = 'NCHW' if use_nchw else 'NHWC'
    with tf.variable_scope(scope) as sc:
        # Sample and Grouping
        if group_all:
            nsample = xyz.get_shape()[1].value
            new_xyz, new_points, idx, grouped_xyz = sample_and_group_all(xyz, points, use_xyz)
        else:
            new_xyz, new_points, idx, grouped_xyz = sample_and_group(npoint, radius, nsample, xyz, points, knn, use_xyz)

        # Point Feature Embedding
        if use_nchw: new_points = tf.transpose(new_points, [0,3,1,2])
        for i, num_out_channel in enumerate(mlp):
            new_points = tf_util.conv2d(new_points, num_out_channel, [1,1],
                                        padding='VALID', stride=[1,1],
                                        bn=bn, is_training=is_training,
                                        scope='conv%d'%(i), bn_decay=bn_decay,
                                        data_format=data_format)
        if use_nchw: new_points = tf.transpose(new_points, [0,2,3,1])

        # Pooling in Local Regions
        if pooling=='max':
            new_points = tf.reduce_max(new_points, axis=[2], keep_dims=True, name='maxpool')
        elif pooling=='avg':
            new_points = tf.reduce_mean(new_points, axis=[2], keep_dims=True, name='avgpool')
        elif pooling=='weighted_avg':
            with tf.variable_scope('weighted_avg'):
                dists = tf.norm(grouped_xyz, axis=-1, ord=2, keep_dims=True)
                exp_dists = tf.exp(-dists * 5)
                weights = exp_dists / tf.reduce_sum(exp_dists, axis=2, keep_dims=True) # (batch_size, npoint, nsample, 1)
                new_points *= weights # (batch_size, npoint, nsample, mlp[-1])
                new_points = tf.reduce_sum(new_points, axis=2, keep_dims=True)
        elif pooling=='max_and_avg':
            max_points = tf.reduce_max(new_points, axis=[2], keep_dims=True, name='maxpool')
            avg_points = tf.reduce_mean(new_points, axis=[2], keep_dims=True, name='avgpool')
            new_points = tf.concat([avg_points, max_points], axis=-1)

        # [Optional] Further Processing
        if mlp2 is not None:
            if use_nchw: new_points = tf.transpose(new_points, [0,3,1,2])
            for i, num_out_channel in enumerate(mlp2):
                new_points = tf_util.conv2d(new_points, num_out_channel, [1,1],
                                            padding='VALID', stride=[1,1],
                                            bn=bn, is_training=is_training,
                                            scope='conv_post_%d'%(i), bn_decay=bn_decay,
                                            data_format=data_format)
            if use_nchw: new_points = tf.transpose(new_points, [0,2,3,1])

        new_points = tf.squeeze(new_points, [2]) # (batch_size, npoints, mlp2[-1])
        return new_xyz, new_points, idx

Each module first samples centroids and groups their local neighborhoods, then extracts features with a shared fully connected network built from three 1×1 convolutions, and finally pools within each region and emits the result; a sketch of the sampling and grouping step follows below.
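In the repository, sampling and grouping are implemented as custom CUDA ops (under tf_ops/). The sketch below is only a minimal NumPy illustration of the two ideas, farthest point sampling and ball query; the function names and toy data are my own, not part of the repository:

import numpy as np

def farthest_point_sample(xyz, npoint):
    """Greedy farthest point sampling: pick npoint indices from xyz of shape (N, 3)."""
    n = xyz.shape[0]
    idx = np.zeros(npoint, dtype=np.int64)
    dist = np.full(n, np.inf)
    farthest = 0  # start from an arbitrary point
    for i in range(npoint):
        idx[i] = farthest
        d = np.sum((xyz - xyz[farthest]) ** 2, axis=1)
        dist = np.minimum(dist, d)          # distance to the nearest chosen centroid
        farthest = int(np.argmax(dist))     # next centroid: farthest from all chosen
    return idx

def ball_query(xyz, centroids, radius, nsample):
    """For each centroid, gather nsample neighbor indices within radius."""
    groups = []
    for c in centroids:
        d = np.linalg.norm(xyz - c, axis=1)
        neighbors = np.where(d < radius)[0]
        if len(neighbors) == 0:
            neighbors = np.array([int(np.argmin(d))])  # fall back to the nearest point
        # pad/truncate to nsample by cyclic repetition, roughly what the CUDA op does
        groups.append(np.resize(neighbors, nsample))
    return np.stack(groups)

xyz = np.random.rand(1024, 3).astype(np.float32)
centroid_idx = farthest_point_sample(xyz, 512)
group_idx = ball_query(xyz, xyz[centroid_idx], radius=0.2, nsample=32)
print(group_idx.shape)  # (512, 32) -- matches layer1's npoint=512, nsample=32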

The classification head differs little from PointNet, so it is not repeated here.
