Pointnet++代码详解(六):PointNetSetAbstraction层

普通的SetAbstraction实现的代码较为简单,主要是前面的一些函数的叠加应用。首先先通过sample_and_group的操作形成局部的group,然后对局部的group中的每一个点做MLP操作,最后进行局部的最大池化,得到局部的全局特征。

class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        ''' 
            PointNet Set Abstraction (SA) Module
                Input:
                    xyz: (batch_size, ndataset, 3) <只包含坐标的点>
                    points: (batch_size, ndataset, channel) 
                    <Note: 不包含坐标,只包含了每个点经过之前层后提取的除坐标外的其他特征,所以第一层没有>
                    npoint: int32 -- #points sampled in farthest point sampling 
                    <Sample layer找512个点作为中心点,这个手工选择的,靠经验或者说靠实验>
                    radius: float32 -- search radius in local region
                    <Grouping layer中ball quary的球半径是0.2,注意这是坐标归一化后的尺度>
                    nsample: int32 -- how many points in each local region
                     <围绕每个中心点在指定半径球内进行采样,上限是32个;半径占主导>
                    mlp: list of int32 -- output size for MLP on each point
                    <PointNet layer有多少层,特征维度变化分别是多少,如(64,64,128)>
                    group_all: bool -- group all points into one PC if set true, OVERRIDE
                        npoint, radius and nsample settings
                Return:
                    new_xyz: (batch_size, npoint, 3) TF tensor
                    new_points: (batch_size, npoint, mlp[-1] or mlp2[-1]) TF tensor
                    idx: (batch_size, npoint, nsample) int32 -- indices for local regions
        '''
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        '''
        nn.ModuleList()
        #在构造函数__init__中用到list、tuple、dict等对象时,
        # 一定要思考是否应该用ModuleList或ParameterList代替。
        #如果你想设计一个神经网络的层数作为输入传递
        '''
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(in_channels=last_channel, out_channels=out_channel, kernel_size=1))
            self.mlp_bns.append(nn.BatchNorm2d(num_features=out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        N是输入点的数量,C是坐标维度(C=3),D是特征维度(除坐标维度以外的其他特征维度)
        S是输出点的数量,C是坐标维度,D'是新的特征维度
            Input:
                xyz: input points position data, [B, C, N]
                points: input points data, [B, D, N]
            Return:
                new_xyz: sampled points position data, [B, C, S]
                new_points_concat: sample points feature data, [B, D', S]
        """
        #permute(dims):将tensor的维度换位,[B, N, 3]
        xyz = xyz.permute(0, 2, 1)
        #判断点是否有其他特征维度
        if points is not None:
            #points:[B, N, D]
            points = points.permute(0, 2, 1)
        if self.group_all:
            #new_xyz:[B, 1, 3],FPS sampled points position data 
            #new_points:[B, 1, N, 3+D], sampled points data
            new_xyz, new_points = samle_and_group_all(xyz, points)
        else:
            #new_xyz:[B, npoint, 3], new_points:[B, npoint, nsample, 3+D]
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_points: sampled points data, [B, 3+D, nsample, npoint] OR [B, 3+D, N, 1]
        new_points = new_points.permute(0, 3, 2, 1)
        # 利用1x1的2d的卷积相当于把每个group当成一个通道,共npoint个通道,
        #对[3+D, nsample]的维度上做逐像素的卷积,结果相当于对单个C+D维度做1d的卷积
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        #对每个group做一个max pooling得到局部的全局特征,得到的new_points:[B,3+D,npoint]
        new_points = torch.max(new_points, 2)[0]
        #new_xyz:[B, 3, npoint]
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points

nn.Modulelist()

为什么有他?

写一个module然后就写foreword函数很麻烦,所以就有了这两个。它被设计用来存储任意数量的nn. module。

什么时候用?

如果在构造函数__init__中用到list、tuple、dict等对象时,一定要思考是否应该用ModuleList或ParameterList代替。

如果你想设计一个神经网络的层数作为输入传递。

和list的区别?

ModuleList是Module的子类,当在Module中使用它的时候,就能自动识别为子module。

当添加 nn.ModuleList 作为 nn.Module 对象的一个成员时(即当我们添加模块到我们的网络时),所有 nn.ModuleList 内部的 nn.Module 的 parameter 也被添加作为 我们的网络的 parameter。

extend和append方法

nn.moduleList定义对象后,有extend和append方法,用法和python中一样,extend是添加另一个modulelist  append是添加另一个module

class LinearNet(nn.Module):

  def __init__(self, input_size, num_layers, layers_size, output_size):

     super(LinearNet, self).__init__()

     self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])

     self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])

     self.linears.append(nn.Linear(layers_size, output_size)

建立以及使用方法

modellist = nn.ModuleList([nn.Linear(3,4), nn.ReLU(), nn.Linear(4,2)])
input = V(t.randn(1, 3))
for model in modellist:
    input = model(input)
# 下面会报错,因为modellist没有实现forward方法
# output = modelist(input)

和普通list不一样,它和torch的其他机制结合紧密,继承了nn.Module的网络模型class可以使用nn.ModuleList并识别其中的parameters,当然这只是个list,不会自动实现forward方法,

​class MyModule(nn.Module):
    def __init__(self):
    super(MyModule, self).__init__()
    self.list = [nn.Linear(3, 4), nn.ReLU()]
    self.module_list = nn.ModuleList([nn.Conv2d(3, 3, 3), nn.ReLU()])

    def forward(self):
        pass

model = MyModule()
print(model)
for name, param in model.named_parameters():
    print(name, param.size())

MyModule(
  (module_list): ModuleList(
    (0): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()))
​
('module_list.0.weight', torch.Size([3, 3, 3, 3]))
('module_list.0.bias', torch.Size([3]))

可见,普通list中的子module并不能被主module所识别,而ModuleList中的子module能够被主module所识别。这意味着如果用list保存子module,将无法调整其参数,因其未加入到主module的参数中。

除ModuleList之外还有ParameterList,其是一个可以包含多个parameter的类list对象。在实际应用中,使用方式与ModuleList类似。
 

 

 

  • 7
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值