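PointNet Code Explained in Detail
I have recently been working on robotic grasping with deep learning on point clouds, and this post collects the key points I picked up while studying PointNet.
For an overview of PointNet, see the following pages and posts; I won't repeat it here.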
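3D deep learning: a detailed explanation of the PointNet series
A detailed analysis of the PointNet network architecture
Understanding the PointNet paper and analyzing its code
Reproducing the PointNet paper, with a detailed code walkthrough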
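Here I focus on the code itself (the files under pointnet-master\models).
PointNet paper and GitHub code download
The detailed network architecture is shown in the figure below.
The main points to note are: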
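(1) The network has two parts, classification and segmentation, which branch apart at the global feature. For how point-cloud classification and segmentation differ and relate, see "Differences and connections between point cloud classification and segmentation".
(2) Here tensor dimensions are written as (B, H, W, C), i.e. Batch, Height, Width, Channel. The input is a 3-D tensor (B, n, 3), where B is the batch size, n is the number of points and 3 is the (x, y, z) position of each point. To make the following convolutions possible, it is expanded to a 4-D tensor (B, n, 3, 1), which gives the convolution kernels a channel axis to expand into C feature channels (see (3) for the exact shape after the first convolution).
(3) The first convolution uses a (1, 3) kernel because every point has the three coordinates (x, y, z). All later kernels are (1, 1), because after the first convolution (with padding='VALID') the data has shape (B, n, 1, C).
(4) The network contains two transform sub-networks (T-Nets). Each T-Net applies convolutions, pooling and fully connected layers to its input to produce a transformation matrix, which is then multiplied with the input to transform the points (or features) without changing the data format. The first T-Net outputs a 3×3 matrix; the second outputs a 64×64 matrix.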
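Points (2) and (3) follow from the output-size rule for a convolution with padding='VALID': along each axis, out = (in - kernel) // stride + 1. The pure-Python sketch below is only a shape calculator (it is not the repo's tf_util). It traces the first two layers, and the batch size 32 and point count 1024 are arbitrary example values.

```python
# Shape calculator for 2-D convolution with padding='VALID' (no padding),
# applied to tensors laid out as (B, H, W, C). No real computation is done.
def valid_out(size, kernel, stride=1):
    """Output size along one axis for a VALID conv/pool."""
    return (size - kernel) // stride + 1

def conv2d_shape(shape, out_channels, kernel, stride=(1, 1)):
    """(B, H, W, C) -> (B, H', W', out_channels) for a VALID conv2d."""
    b, h, w, _ = shape
    return (b, valid_out(h, kernel[0], stride[0]),
            valid_out(w, kernel[1], stride[1]), out_channels)

B, n = 32, 1024                   # example batch size and number of points
x = (B, n, 3, 1)                  # (B, n, 3) after expand_dims(-1)
x = conv2d_shape(x, 64, (1, 3))   # [1,3] kernel spans x, y, z: W shrinks 3 -> 1
print(x)                          # (32, 1024, 1, 64)
x = conv2d_shape(x, 128, (1, 1))  # [1,1] kernel keeps H and W, only C changes
print(x)                          # (32, 1024, 1, 128)
```

Because W is already 1 after the first layer, any wider kernel would have nothing left to slide over. This is why every later kernel is [1,1] and acts like a per-point fully connected layer.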
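Before reading the code below, read the following post carefully and make sure you understand it:
A detailed analysis of the PointNet network architecture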
T-Net
The corresponding file is "pointnet-master\models\transform_nets.py".
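From the architecture diagram, the input is B×n×3. input_transform goes through the following stages:
Convolution: 64–128–1024
Max pooling over all n points, giving a 1024-dimensional global feature
Fully connected: 1024–512–256–3*K (the code uses K=3)
Finally, a reshape produces the transformation matrix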
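Between the convolution and fully connected stages, the code max-pools over all points with a [num_point, 1] window (`tmaxpool` in the code below). The sketch below is a pure-Python stand-in for that step, not the TensorFlow op. For a (num_point, 1, C) feature map it keeps, for each channel, the maximum over all points. The feature values are made up for illustration.

```python
# Pure-Python stand-in for tf_util.max_pool2d(net, [num_point, 1], padding='VALID')
# on a single sample: the window covers every point, so one value per channel survives.
def max_pool_over_points(features):
    """features: num_point rows, each a list of C channel values
    (the size-1 W axis is dropped). Returns the C per-channel maxima."""
    num_channels = len(features[0])
    return [max(row[c] for row in features) for c in range(num_channels)]

points_feat = [
    [0.1, 2.0, -1.0],   # point 0, C = 3 example channels
    [0.5, 1.0, -3.0],   # point 1
    [0.3, 4.0, -2.0],   # point 2
]
global_feat = max_pool_over_points(points_feat)
print(global_feat)      # [0.5, 4.0, -1.0]
```

Because max is a symmetric function, reordering the points leaves the pooled vector unchanged. The per-point features can therefore be reduced to one global feature regardless of the order in which the points are stored.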
```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.get_variable, tf.variable_scope)
import tf_util           # helper module from the PointNet repo (utils/tf_util.py)

def input_transform_net(point_cloud, is_training, bn_decay=None, K=3):
    """ Input (XYZ) Transform Net, input is BxNx3 gray image
        Return:
            Transformation matrix of size 3xK """
    # Tip: while reading, ignore the batch_size dimension and treat each tensor
    # as a 3-D array [n, W, C], where n is the number of points (the Height),
    # W is the matrix width and C is the number of feature channels.
    batch_size = point_cloud.get_shape()[0].value  # batch size
    num_point = point_cloud.get_shape()[1].value   # number of points

    # Add a trailing axis (-1 = last dimension) to get a 4-D tensor;
    # the new 4th dimension holds the feature channels.
    # input_image: [batch_size, num_point, 3, 1]
    input_image = tf.expand_dims(point_cloud, -1)
    # 1st conv: 64 kernels of size [1,3] -> net: [batch_size, num_point, 1, 64]
    net = tf_util.conv2d(input_image, 64, [1,3],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='tconv1', bn_decay=bn_decay)
    # 2nd conv: 128 kernels of size [1,1] -> net: [batch_size, num_point, 1, 128]
    net = tf_util.conv2d(net, 128, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='tconv2', bn_decay=bn_decay)
    # 3rd conv: 1024 kernels of size [1,1] -> net: [batch_size, num_point, 1, 1024]
    net = tf_util.conv2d(net, 1024, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='tconv3', bn_decay=bn_decay)
    # Max pooling with a [num_point, 1] window: each feature channel keeps a single
    # value (its maximum over all points) -> net: [batch_size, 1, 1, 1024]
    net = tf_util.max_pool2d(net, [num_point,1],
                             padding='VALID', scope='tmaxpool')

    # Flatten to [batch_size, 1024]
    net = tf.reshape(net, [batch_size, -1])
    net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
                                  scope='tfc1', bn_decay=bn_decay)
    net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
                                  scope='tfc2', bn_decay=bn_decay)
    # net is now [batch_size, 256], i.e. 256 features per sample.
    # The fully connected layer below maps it to a [batch_size, 3*K] transform.
    with tf.variable_scope('transform_XYZ') as sc:
        assert(K==3)
        # Weights start at 0 and biases at the flattened 3x3 identity,
        # so the initial transform is the identity matrix.
        weights = tf.get_variable('weights', [256, 3*K],
                                  initializer=tf.constant_initializer(0.0),
                                  dtype=tf.float32)
        biases = tf.get_variable('biases', [3*K],
                                 initializer=tf.constant_initializer(0.0),
                                 dtype=tf.float32)
        biases += tf.constant([1,0,0,0,1,0,0,0,1], dtype=tf.float32)
        transform = tf.matmul(net, weights)
        transform = tf.nn.bias_add(transform, biases)

    # Reshape the transform from [batch_size, 3*K] to [batch_size, 3, K]
    transform = tf.reshape(transform, [batch_size, 3, K])
    return transform
```
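One detail in `transform_XYZ` is easy to miss. The weights are initialized to 0 and the biases to [1,0,0,0,1,0,0,0,1], so at initialization the predicted matrix is the 3×3 identity and `tf.matmul(point_cloud, transform)` leaves the points unchanged. Training then moves the matrix away from the identity as needed. The sketch below checks this in plain Python for a single sample, with the batch dimension dropped. `net` and `points` hold made-up values, and `matmul` is a hand-written helper, not a TensorFlow call.

```python
# Why the T-Net starts as the identity: with zero weights, net @ weights is all
# zeros for ANY feature vector, so transform = biases = flattened eye(3).
def matmul(a, b):
    """Plain matrix product of two lists of lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

K = 3
net = [[0.7] * 256]                               # example (1, 256) tfc2 output
weights = [[0.0] * (3 * K) for _ in range(256)]   # initial weights: all zeros
biases = [1, 0, 0, 0, 1, 0, 0, 0, 1]              # constant added in transform_XYZ
flat = [v + b for v, b in zip(matmul(net, weights)[0], biases)]  # bias_add
transform = [flat[r * K:(r + 1) * K] for r in range(3)]          # reshape to (3, K)

points = [[1.0, 2.0, 3.0], [-1.0, 0.5, 4.0]]      # example (n, 3) cloud, n = 2
print(transform)                 # [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(matmul(points, transform)) # [[1.0, 2.0, 3.0], [-1.0, 0.5, 4.0]]
```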
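feature_transform is similar to input_transform. Its stages are:
Convolution: 64–128–1024
Max pooling over all n points, giving a 1024-dimensional global feature
Fully connected: 1024–512–256–K*K (the code uses K=64, i.e. 64*64 outputs)
Finally, a reshape produces the transformation matrix
I won't annotate this one. You can use the uncommented code below to check whether you understood the previous transform.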
```python
def feature_transform_net(inputs, is_training, bn_decay=None, K=64):
    """ Feature Transform Net, input is BxNx1xK
        Return:
            Transformation matrix of size KxK """
    batch_size = inputs.get_shape()[0].value
    num_point = inputs.get_shape()[1].value

    net = tf_util.conv2d(inputs, 64, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='tconv1', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 128, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='tconv2', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 1024, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='tconv3', bn_decay=bn_decay)
    net = tf_util.max_pool2d(net, [num_point,1],
                             padding='VALID', scope='tmaxpool')

    net = tf.reshape(net, [batch_size, -1])
    net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
                                  scope='tfc1', bn_decay=bn_decay)
    net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
                                  scope='tfc2', bn_decay=bn_decay)

    with tf.variable_scope('transform_feat') as sc:
        weights = tf.get_variable('weights', [256, K*K],
                                  initializer=tf.constant_initializer(0.0),
                                  dtype=tf.float32)
        biases = tf.get_variable('biases', [K*K],
                                 initializer=tf.constant_initializer(0.0),
                                 dtype=tf.float32)
        biases += tf.constant(np.eye(K).flatten(), dtype=tf.float32)
        transform = tf.matmul(net, weights)
        transform = tf.nn.bias_add(transform, biases)

    transform = tf.reshape(transform, [batch_size, K, K])
    return transform
```
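feature_transform_net uses the same identity initialization as input_transform_net, written generically as `np.eye(K).flatten()`. In row-major order the K×K identity flattens to 1 at indices that are multiples of K+1 and 0 elsewhere. For K=3 this is exactly the constant [1,0,0,0,1,0,0,0,1] hard-coded in input_transform_net. The sketch below reproduces this in pure Python. `flat_identity` and `reshape_square` are illustrative helpers, not part of the repo.

```python
# Pure-Python stand-in for np.eye(K).flatten() and the final reshape to (K, K).
def flat_identity(K):
    """Row-major flattening of the K x K identity: 1.0 where index % (K + 1) == 0."""
    return [1.0 if i % (K + 1) == 0 else 0.0 for i in range(K * K)]

def reshape_square(flat, K):
    """Inverse of the flattening: K*K values -> K rows of K values."""
    return [flat[r * K:(r + 1) * K] for r in range(K)]

# For K = 3 this matches the constant used in input_transform_net.
assert flat_identity(3) == [1, 0, 0, 0, 1, 0, 0, 0, 1]

K = 64
bias = flat_identity(K)
print(len(bias))      # 4096, the number of outputs of the last FC layer (64 * 64)
m = reshape_square(bias, K)
is_identity = all(m[i][j] == (1.0 if i == j else 0.0)
                  for i in range(K) for j in range(K))
print(is_identity)    # True
```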
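Classification network
The classification network is the main block at the top of the PointNet architecture diagram, the Classification Network. The corresponding file is "pointnet-master\models\pointnet_cls.py".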
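The step in get_model most likely to trip you up is applying the 64×64 feature transform. `tf.matmul` needs a 3-D [batch_size, num_point, 64] tensor, so the code squeezes out axis 2, multiplies, and then restores the axis with `expand_dims`. The shape-only sketch below mimics that sequence with tuples. `squeeze`, `expand_dims` and `batched_matmul` here are hypothetical helpers that only compute shapes, and B=32, n=1024 are example values.

```python
# Shape-only mimic of: tf.squeeze(net, axis=[2]) -> tf.matmul(..., transform)
# -> tf.expand_dims(..., 2). Shapes are plain tuples; no tensors involved.
def squeeze(shape, axis):
    """Drop a size-1 axis, as tf.squeeze does."""
    assert shape[axis] == 1, "only size-1 dimensions can be squeezed"
    return shape[:axis] + shape[axis + 1:]

def expand_dims(shape, axis):
    """Insert a size-1 axis at position `axis`."""
    return shape[:axis] + (1,) + shape[axis:]

def batched_matmul(a, b):
    """(B, n, K) @ (B, K, M) -> (B, n, M); batch and inner dims must match."""
    assert a[0] == b[0] and a[2] == b[1]
    return (a[0], a[1], b[2])

B, n, K = 32, 1024, 64
net = (B, n, 1, K)                  # output of conv2
x = squeeze(net, 2)                 # (32, 1024, 64)
x = batched_matmul(x, (B, K, K))    # (32, 1024, 64), features transformed
x = expand_dims(x, 2)               # (32, 1024, 1, 64), ready for conv3
print(x)                            # (32, 1024, 1, 64)
```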
```python
def get_model(point_cloud, is_training, bn_decay=None):
    """ Classification PointNet, input is BxNx3, output Bx40 """
    batch_size = point_cloud.get_shape()[0].value  # batch size
    num_point = point_cloud.get_shape()[1].value   # number of points
    end_points = {}                                # dict of auxiliary outputs

    with tf.variable_scope('transform_net1') as sc:
        # the first T-Net produces input_transform
        transform = input_transform_net(point_cloud, is_training, bn_decay, K=3)
    # multiply the point cloud by input_transform to apply the transformation
    point_cloud_transformed = tf.matmul(point_cloud, transform)
    # expand to a 4-D tensor [batch_size, num_point, 3, 1]
    input_image = tf.expand_dims(point_cloud_transformed, -1)

    # 64 kernels of size [1,3] -> 64 feature channels
    net = tf_util.conv2d(input_image, 64, [1,3],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv1', bn_decay=bn_decay)
    # 64 kernels of size [1,1] -> 64 feature channels
    net = tf_util.conv2d(net, 64, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv2', bn_decay=bn_decay)
    # net is now [batch_size, num_point, 1, 64]

    # second T-Net: the feature transform
    with tf.variable_scope('transform_net2') as sc:
        transform = feature_transform_net(net, is_training, bn_decay, K=64)
    # keep the feature transform in end_points; the loss uses it as a regularization term
    end_points['transform'] = transform

    # tf.squeeze removes size-1 dimensions. net is [batch_size, num_point, 1, 64];
    # axes are counted from 0, so axis=[2] is the size-1 third dimension and the
    # squeezed tensor is [batch_size, num_point, 64]. Multiplying it by the
    # 64x64 transform gives the transformed features.
    net_transformed = tf.matmul(tf.squeeze(net, axis=[2]), transform)
    # re-insert axis 2 -> 4-D tensor [batch_size, num_point, 1, 64]
    net_transformed = tf.expand_dims(net_transformed, [2])

    # three more convolutions, ending at [batch_size, num_point, 1, 1024]
    net = tf_util.conv2d(net_transformed, 64, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv3', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 128, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv4', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 1024, [1,1],
                         padding='VALID', stride
operator">=</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">,</span> bn<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">,</span> is_training<span class="token operator">=</span>is_training<span class="token punctuation">,</span> scope<span class="token operator">=</span><span class="token string">'conv5'</span><span class="token punctuation">,</span> bn_decay<span class="token operator">=</span>bn_decay<span class="token punctuation">)</span> <span class="token comment"># Symmetric function: max pooling</span> <span class="token comment">#max pooling,扫描的模板大小为[num_point, 1],也就是每一个特征通道仅保留一个feature</span> <span class="token comment">#完成该池化操作之后的net表示为[batch_size, 1, 1, 1024]</span> net <span class="token operator">=</span> tf_util<span class="token punctuation">.</span>max_pool2d<span class="token punctuation">(</span>net<span class="token punctuation">,</span> <span class="token punctuation">[</span>num_point<span class="token punctuation">,</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">,</span> padding<span class="token operator">=</span><span class="token string">'VALID'</span><span class="token punctuation">,</span> scope<span class="token operator">=</span><span class="token string">'maxpool'</span><span class="token punctuation">)</span> <span class="token comment">#重塑张量得到[batch_size, 1024]</span> net <span class="token operator">=</span> tf<span class="token punctuation">.</span>reshape<span class="token punctuation">(</span>net<span class="token punctuation">,</span> <span class="token punctuation">[</span>batch_size<span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1</span><span class="token 
punctuation">]</span><span class="token punctuation">)</span> <span class="token comment">#3次全连接并且进行2次dropout</span> <span class="token comment">#最后得到40种分类结果</span> net <span class="token operator">=</span> tf_util<span class="token punctuation">.</span>fully_connected<span class="token punctuation">(</span>net<span class="token punctuation">,</span> <span class="token number">512</span><span class="token punctuation">,</span> bn<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">,</span> is_training<span class="token operator">=</span>is_training<span class="token punctuation">,</span> scope<span class="token operator">=</span><span class="token string">'fc1'</span><span class="token punctuation">,</span> bn_decay<span class="token operator">=</span>bn_decay<span class="token punctuation">)</span> net <span class="token operator">=</span> tf_util<span class="token punctuation">.</span>dropout<span class="token punctuation">(</span>net<span class="token punctuation">,</span> keep_prob<span class="token operator">=</span><span class="token number">0.7</span><span class="token punctuation">,</span> is_training<span class="token operator">=</span>is_training<span class="token punctuation">,</span> scope<span class="token operator">=</span><span class="token string">'dp1'</span><span class="token punctuation">)</span> net <span class="token operator">=</span> tf_util<span class="token punctuation">.</span>fully_connected<span class="token punctuation">(</span>net<span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">,</span> bn<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">,</span> is_training<span class="token operator">=</span>is_training<span class="token punctuation">,</span> scope<span class="token operator">=</span><span class="token string">'fc2'</span><span class="token punctuation">,</span> 
bn_decay<span class="token operator">=</span>bn_decay<span class="token punctuation">)</span> net <span class="token operator">=</span> tf_util<span class="token punctuation">.</span>dropout<span class="token punctuation">(</span>net<span class="token punctuation">,</span> keep_prob<span class="token operator">=</span><span class="token number">0.7</span><span class="token punctuation">,</span> is_training<span class="token operator">=</span>is_training<span class="token punctuation">,</span> scope<span class="token operator">=</span><span class="token string">'dp2'</span><span class="token punctuation">)</span> net <span class="token operator">=</span> tf_util<span class="token punctuation">.</span>fully_connected<span class="token punctuation">(</span>net<span class="token punctuation">,</span> <span class="token number">40</span><span class="token punctuation">,</span> activation_fn<span class="token operator">=</span><span class="token boolean">None</span><span class="token punctuation">,</span> scope<span class="token operator">=</span><span class="token string">'fc3'</span><span class="token punctuation">)</span> <span class="token comment">#return返回分类结果以及n*64的原始特征</span> <span class="token keyword">return</span> net<span class="token punctuation">,</span> end_points
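The comments above describe three steps. The 64x64 feature transform keeps the tensor shape. The [1,1] convolutions apply the same weights to every point. Max pooling acts as a symmetric function.

A minimal pure-Python sketch makes the last point concrete: shuffling the input points does not change the pooled global feature. It assumes the following simplifications:

* The batch dimension is dropped.
* Batch norm and biases are omitted.
* Toy sizes are used (K = 4 and 16 output channels instead of 64 and 1024).
* Random weights stand in for learned ones.

The helper names are hypothetical and are not part of the PointNet repository.

```python
import random

# Hypothetical helpers for illustration; not part of the PointNet repository.


def apply_transform(features, transform):
    """Multiply every point's feature row by a K x K matrix: [N, K] x [K, K] -> [N, K].

    Mirrors tf.matmul(tf.squeeze(net, axis=[2]), transform) with the batch
    dimension dropped; the output keeps the input's shape.
    """
    k = len(transform)
    return [[sum(row[i] * transform[i][j] for i in range(k)) for j in range(k)]
            for row in features]


def shared_point_layer(features, weights):
    """Apply the same weights + ReLU to each point independently.

    This is what a [1,1] conv over an [N, 1, C_in] tensor computes (no bias, no BN):
    [N, C_in] x [C_in, C_out] -> [N, C_out], identical weights for every point.
    """
    c_in, c_out = len(weights), len(weights[0])
    return [[max(0.0, sum(row[i] * weights[i][j] for i in range(c_in)))
             for j in range(c_out)]
            for row in features]


def max_pool_over_points(features):
    """Channel-wise max over all points: [N, C] -> [C] (the global feature)."""
    return [max(column) for column in zip(*features)]


def global_feature(features, transform, weights):
    """Feature transform -> shared per-point layer -> max pooling."""
    return max_pool_over_points(
        shared_point_layer(apply_transform(features, transform), weights))


random.seed(0)
num_point, k, c_out = 8, 4, 16  # toy sizes instead of N, 64, 1024
features = [[random.uniform(-1, 1) for _ in range(k)] for _ in range(num_point)]
transform = [[random.uniform(-1, 1) for _ in range(k)] for _ in range(k)]
weights = [[random.uniform(-1, 1) for _ in range(c_out)] for _ in range(k)]

g = global_feature(features, transform, weights)

# Same points in a different order
shuffled = features[:]
random.shuffle(shuffled)
g_shuffled = global_feature(shuffled, transform, weights)

print(len(apply_transform(features, transform)), len(g))  # 8 16
print(g == g_shuffled)  # True: the point order does not matter
```

Every point goes through the same weights, and max does not depend on order. The global feature therefore depends only on the set of points, not on their order. This is the property the "Symmetric function" comment refers to.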
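Segmentation Network Structure
In the original PointNet architecture diagram, this is the branch that leaves the global feature and goes downward, i.e. the Segmentation Network. The corresponding file is "pointnet-master\models\pointnet_seg.py".
Its structure is similar to the classification network, so only brief comments are given here.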
```python
def get_model(point_cloud, is_training, bn_decay=None):
    """ Segmentation PointNet, input is BxNx3, output BxNx50 """
    batch_size = point_cloud.get_shape()[0].value
    num_point = point_cloud.get_shape()[1].value
    end_points = {}

    # First T-Net (input transform), then expand to a 4-D tensor
    with tf.variable_scope('transform_net1') as sc:
        transform = input_transform_net(point_cloud, is_training, bn_decay, K=3)
    point_cloud_transformed = tf.matmul(point_cloud, transform)
    input_image = tf.expand_dims(point_cloud_transformed, -1)

    # Two convolutions, giving [batch_size, num_point, 1, 64]
    net = tf_util.conv2d(input_image, 64, [1,3],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv1', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 64, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv2', bn_decay=bn_decay)

    # Second T-Net (feature transform); note the squeeze/expand_dims around matmul
    with tf.variable_scope('transform_net2') as sc:
        transform = feature_transform_net(net, is_training, bn_decay, K=64)
    end_points['transform'] = transform
    net_transformed = tf.matmul(tf.squeeze(net, axis=[2]), transform)
    # Per-point (local) features: [batch_size, num_point, 1, 64]
    point_feat = tf.expand_dims(net_transformed, [2])
    print(point_feat)

    # Three convolutions and one max pooling,
    # ending at [batch_size, 1, 1, 1024]
    net = tf_util.conv2d(point_feat, 64, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv3', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 128, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv4', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 1024, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv5', bn_decay=bn_decay)
    global_feat = tf_util.max_pool2d(net, [num_point,1],
                                     padding='VALID', scope='maxpool')
    print(global_feat)

    # Before tiling, global_feat is [batch_size, 1, 1, 1024]
    # tf.tile() replicates axis 1 num_point times,
    # giving [batch_size, num_point, 1, 1024]
    global_feat_expand = tf.tile(global_feat, [1, num_point, 1, 1])
    # Concatenate along the channel axis to get the n x 1088 (64 + 1024) matrix
    # from the diagram (n = num_point, batch_size ignored).
    # TensorFlow 1.0+ uses tf.concat(values, axis); the older form was tf.concat(3, [...])
    concat_feat = tf.concat(axis=3, values=[point_feat, global_feat_expand])
    print(concat_feat)

    # Four convolutions reduce the features to 128 channels:
    # [batch_size, num_point, 1, 128]
    net = tf_util.conv2d(concat_feat, 512, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv6', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 256, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv7', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 128, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv8', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 128, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv9', bn_decay=bn_decay)

    # A final [1,1] convolution without activation (not max pooling) gives
    # 50 scores per point, one per part label: [batch_size, num_point, 1, 50]
    net = tf_util.conv2d(net, 50, [1,1],
                         padding='VALID', stride=[1,1], activation_fn=None,
                         scope='conv10')
    # squeeze() removes the size-1 axis 2 (0-based, i.e. the third dimension)
    net = tf.squeeze(net, [2]) # BxNxC
    return net, end_points
```
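The tile-and-concat step is where segmentation differs from classification: every point receives a copy of the global feature next to its own local feature. The pure-Python sketch below reproduces the n × 1088 layout. It assumes the batch dimension is dropped and uses toy values. The function name `tile_and_concat` is hypothetical and is not part of the repository.

```python
# Hypothetical helper for illustration; not part of the PointNet repository.
def tile_and_concat(point_feat, global_feat):
    """Mimic tf.tile + tf.concat on the channel axis, batch dimension dropped.

    point_feat:  N rows of per-point (local) features, e.g. length 64
    global_feat: one global feature vector, e.g. length 1024
    returns:     N rows of length len(local) + len(global), e.g. 1088
    """
    num_point = len(point_feat)
    # Like tf.tile(global_feat, [1, num_point, 1, 1]): one copy of the global vector per point
    global_expand = [list(global_feat) for _ in range(num_point)]
    # Like tf.concat(..., axis=3): append the global vector to each point's local feature
    return [local + glob for local, glob in zip(point_feat, global_expand)]


num_point = 5
point_feat = [[float(p)] * 64 for p in range(num_point)]  # toy local features
global_feat = [9.0] * 1024                               # toy global feature

concat_feat = tile_and_concat(point_feat, global_feat)
print(len(concat_feat), len(concat_feat[0]))  # 5 1088
```

Each 1088-d row then passes through the shared [1,1] convolutions conv6 to conv10. The network can therefore output 50 scores for every point, while each point still sees the whole shape through the global part of its row.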
That is all for this code walkthrough. If you find any mistakes, please point them out in the comments.