Understanding pointnet_util.py in the PointNet & PointNet++ Source Code

Source: https://github.com/yanx27/Pointnet_Pointnet2_pytorch

File: pointnet_util.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np


# Print the elapsed time since t and return a fresh timestamp
def timeit(tag, t):
    print("{}: {}s".format(tag, time() - t))
    return time()


# pc_normalize: point cloud normalization
# Centers the point cloud at the origin and scales it to fit inside the unit sphere
def pc_normalize(pc):
    l = pc.shape[0]
    centroid = np.mean(pc, axis=0)  # mean of x, y, z over all points, i.e. the centroid
    pc = pc - centroid  # offset of every point from the centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))  # largest distance from any point to the centroid
    pc = pc / m  # scale by the max radius so all points lie in the unit sphere (note: the divisor is the max distance, not the standard deviation, so this is not Z-score standardization)
    return pc
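
As a quick sanity check (my own snippet, not part of pointnet_util.py), the following shows that pc_normalize centers a cloud at the origin and scales its farthest point to radius 1:

pc = np.random.rand(2048, 3) * 10.0                    # hypothetical raw cloud
pc_n = pc_normalize(pc)
print(np.allclose(pc_n.mean(axis=0), 0, atol=1e-6))    # True: centered at the origin
print(np.sqrt((pc_n ** 2).sum(axis=1)).max())          # ~1.0: max radius is 1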


# Computes the squared distance from every point to every sampled point; used by ball_query
def square_distance(src, dst):
    # During training, data comes in mini-batches,
    # so there is a batch dimension B.
    """
        Input:
            src: source points, [B, N, C]
            dst: target points, [B, M, C]
        Output:
            dist: per-point square distance, [B, N, M]

        N and M are the numbers of points in src and dst.

        dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
             = sum(src**2, dim=-1) + sum(dst**2, dim=-1) - 2 * src @ dst^T

    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    # Batched matmul gives the cross term -2 * (xn*xm + yn*ym + zn*zm);
    # to make src and dst multiplicable (3-D batched matrix multiply),
    # dst is permuted to [B, C, M]
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    # add xn*xn + yn*yn + zn*zn
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    # add xm*xm + ym*ym + zm*zm
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist
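
To convince yourself the expansion is right, here is a small check (my own snippet, assuming the function above is in scope) against torch.cdist:

src = torch.rand(2, 16, 3)                 # hypothetical shapes
dst = torch.rand(2, 32, 3)
d1 = square_distance(src, dst)             # [2, 16, 32]
d2 = torch.cdist(src, dst) ** 2            # reference squared distances
print(torch.allclose(d1, d2, atol=1e-5))   # True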


# Given the point cloud and an index tensor, return the indexed points
def index_points(points, idx):
    """
        Input:
            points: input points data, [B, N, C]
            idx: sample index data, [B, D1,...DN]
        Return:
            new_points: indexed points data, [B, D1,...DN, C]
    """
    # idx holds the indices of the points to pick from each sample.
    # For input point clouds of shape B*2048*3 (B = batch size),
    # this function simply selects, within each sample, the points whose indices appear in idx.
    # If idx has length 4, the output is B*4*3: 4 of the 2048 points are picked per sample.
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    # view_shape=[B,1] (for a 2-D idx of shape [B,S])
    view_shape[1:] = [1] * (len(view_shape) - 1)
    # repeat_shape=[B,S]
    repeat_shape = list(idx.shape)
    # repeat_shape=[1,S]
    repeat_shape[0] = 1
    # .view(view_shape) = .view(B, 1)
    # .repeat(repeat_shape) = .repeat(1, S)
    # so batch_indices has shape [B,S]
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    # pick from points the entries at (batch_indices, idx)
    new_points = points[batch_indices, idx, :]
    return new_points
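
A minimal usage sketch (my own, with hypothetical shapes): picking 4 points per sample from a [B, N, 3] cloud, exactly the B*4*3 case described above:

points = torch.rand(2, 2048, 3)
idx = torch.randint(0, 2048, (2, 4))   # [B, 4] indices per sample
picked = index_points(points, idx)
print(picked.shape)                    # torch.Size([2, 4, 3])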


# Farthest point sampling (FPS)
def farthest_point_sample(xyz, npoint):
    """
        Input:
            xyz: pointcloud data, [B, N, 3]
            npoint: number of points to sample
        Return:
            centroids: sampled pointcloud indices, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape  # B: batch size, N: number of points, C: channels
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)  # indices of the npoint sampled centroids
    distance = torch.ones(B, N).to(device) * 1e10  # distance from every point of a sample to the sampled set
    # farthest is the current farthest point, initialized randomly in [0, N); one per batch element
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    # batch_indices initialized to 0..B-1
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        # record the i-th farthest point
        centroids[:, i] = farthest
        # fetch the xyz coordinates of this farthest point
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        # squared Euclidean distance from all points to this centroid
        dist = torch.sum((xyz - centroid) ** 2, -1)
        # update distance: for each point, keep the minimum distance to any centroid sampled so far
        mask = dist < distance
        distance[mask] = dist[mask]
        # the point with the largest minimum distance becomes the farthest point for the next iteration
        farthest = torch.max(distance, -1)[1]
    return centroids
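
A minimal usage sketch (my own, hypothetical sizes): sampling 512 centroid indices from a 2048-point cloud and fetching their coordinates with index_points:

xyz = torch.rand(2, 2048, 3)
idx = farthest_point_sample(xyz, 512)   # [2, 512] point indices
centroids = index_points(xyz, idx)      # [2, 512, 3]
print(centroids.shape)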


# Find the points inside each spherical neighborhood (ball query)
def query_ball_point(radius, nsample, xyz, new_xyz):
    # radius: radius of the spherical neighborhood
    # nsample: number of points to sample in each neighborhood
    # new_xyz: the S ball centers (produced earlier by farthest point sampling)
    # xyz: all points in the cloud
    # Output: for each sample, the indices of the nsample points in each ball
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    # sqrdists: [B, S, N], squared Euclidean distance from each center to every point
    sqrdists = square_distance(new_xyz, xyz)
    # wherever the distance exceeds radius^2, set group_idx to N; keep the rest unchanged
    group_idx[sqrdists > radius ** 2] = N
    # sort ascending; out-of-ball entries are N (the maximum), so the first nsample entries come from the in-ball points
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # the first nsample entries may still contain N (i.e. the ball holds fewer than nsample points);
    # such entries are discarded by replacing them with the first point of the group.
    # group_first: [B, S, nsample], the first index of each group broadcast to [B, S, nsample] for the replacement
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    # locate the entries in group_idx that equal N
    mask = group_idx == N
    # replace them with the first point of the group
    group_idx[mask] = group_first[mask]
    return group_idx
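
A minimal usage sketch (my own, hypothetical radius and sizes): a ball query around FPS centroids; whenever a ball holds fewer than nsample points, the duplicated first index fills the gap as described above:

xyz = torch.rand(2, 2048, 3)
new_xyz = index_points(xyz, farthest_point_sample(xyz, 512))
group_idx = query_ball_point(0.2, 32, xyz, new_xyz)
print(group_idx.shape)                  # torch.Size([2, 512, 32])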


# Split the whole point cloud into local groups; PointNet can then extract a local "global" feature from each group
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
        Input:
            npoint: number of points for FPS
            radius: radius of the ball query
            nsample: number of points per ball query
            xyz: point position data, [B, N, C]
            points: additional per-point features, [B, N, D]
        Return:
            new_xyz: sampled point positions, [B, npoint, C]
            new_points: sampled points data, [B, npoint, nsample, C+D]

    """

    B, N, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
    torch.cuda.empty_cache()
    # pick the FPS-sampled points out of the original cloud as new_xyz
    new_xyz = index_points(xyz, fps_idx)  # new_xyz holds the centroids, shape [B, S, 3]
    torch.cuda.empty_cache()
    # idx: [B, npoint, nsample], indices of the nsample points in each of the npoint balls
    idx = query_ball_point(radius, nsample, xyz, new_xyz)
    torch.cuda.empty_cache()
    # grouped_xyz: [B, npoint, nsample, C]
    # index_points picks the nsample points of every group out of the original cloud
    grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
    torch.cuda.empty_cache()
    # subtract the centroid from grouped_xyz, i.e. express each group relative to its center
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
    torch.cuda.empty_cache()

    # if the points carry extra feature dimensions, concatenate them with the relative coordinates; otherwise return the relative coordinates alone
    if points is not None:
        # pick the non-coordinate feature channels of the nsample points in every group
        grouped_points = index_points(points, idx)
        # dim=-1 concatenates along the last dimension, equivalent to dim=3 here
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points
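
A minimal usage sketch (my own, hypothetical values): grouping with npoint=512, radius=0.2, nsample=32 on a cloud that carries 16 extra feature channels, so the grouped features end up with 3+16 channels:

xyz = torch.rand(2, 1024, 3)
feats = torch.rand(2, 1024, 16)
new_xyz, new_points = sample_and_group(512, 0.2, 32, xyz, feats)
print(new_xyz.shape)                    # torch.Size([2, 512, 3])
print(new_points.shape)                 # torch.Size([2, 512, 32, 19])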


# Like sample_and_group, but PointNet then extracts one feature for the whole cloud
def sample_and_group_all(xyz, points):
    # The only difference from the previous function: all points form a single group, i.e. only a length-1 dimension is added
    """
        Input:
            xyz: input points position data, [B, N, C]
            points: input points data, [B, N, D]
        Return:
            new_xyz: sampled points position data, [B, 1, C]
            new_points: sampled points data, [B, 1, N, C+D]

    """
    device = xyz.device
    B, N, C = xyz.shape
    # new_xyz is the group center, here the origin
    new_xyz = torch.zeros(B, 1, C).to(device)
    # subtracting the center would normally re-center the group, but since the center is the origin, the result is grouped_xyz itself

    grouped_xyz = xyz.view(B, 1, N, C)
    if points is not None:
        # view(B, 1, N, -1): -1 is inferred automatically, equivalent to view(B, 1, N, D)
        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points
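
A minimal usage sketch (my own, hypothetical shapes): the whole cloud becomes a single group:

xyz = torch.rand(2, 1024, 3)
feats = torch.rand(2, 1024, 16)
new_xyz, new_points = sample_and_group_all(xyz, feats)
print(new_xyz.shape)                    # torch.Size([2, 1, 3])
print(new_points.shape)                 # torch.Size([2, 1, 1024, 19])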


class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        """
        Input:
        npoint: number of points for FPS sampling
        radius: radius for the ball query
        nsample: number of points per ball query
        in_channel: the dimension of the input channels
        mlp: a list of MLP output channels, e.g. [64, 64, 128]
        group_all: bool, whether to group all points into a single group
        """

        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        # When a list, tuple, or dict of modules is used in __init__,
        # consider whether it should be a ModuleList or ParameterList instead,
        # e.g. when the number of layers is passed in as an argument
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):

        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)   # permute(dims): reorder tensor dimensions to [B, N, D]

        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            # new_xyz:[B, npoint, 3], new_points:[B, npoint, nsample, 3+D]
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)

        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]

        # The 1x1 Conv2d runs on the [B, C+D, nsample, npoint] tensor: the C+D feature dims are the
        # channels and every (group point, group) location is a "pixel", so this is a shared per-point
        # MLP, equivalent to applying the same 1-D convolution to each point's C+D-dim feature
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points =  F.relu(bn(conv(new_points)))

        # max pooling over each group yields the local "global" feature; new_points: [B, D', npoint] with D' = mlp[-1]
        new_points = torch.max(new_points, 2)[0]
        # new_xyz:[B, 3, npoint]
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points
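
A minimal usage sketch (my own; the hyperparameters mirror the first SA layer of the repo's SSG classifier, but treat them as illustrative). Note the channel-first [B, C, N] input convention of forward:

sa = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32,
                            in_channel=3, mlp=[64, 64, 128], group_all=False)
xyz = torch.rand(2, 3, 1024)            # channel-first positions
new_xyz, new_points = sa(xyz, None)
print(new_xyz.shape)                    # torch.Size([2, 3, 512])
print(new_points.shape)                 # torch.Size([2, 128, 512])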


# MSG实现
class PointNetSetAbstractionMsg(nn.Module):
    '''
         PointNet Set Abstraction (SA) module with Multi-Scale Grouping (MSG)
         Input:
             xyz: (batch_size, ndataset, 3) tensor
             points: (batch_size, ndataset, channel) tensor
             npoint: int32 -- #points sampled by farthest point sampling
             radius_list: list of float32 -- search radii for the local regions
             nsample_list: list of int32 -- how many points in each local region
             mlp_list: list of lists of int32 -- output sizes of the MLP on each point
         Return:
             new_xyz: (batch_size, npoint, 3) tensor
             new_points: (batch_size, npoint, sum_k{mlp[k][-1]}) tensor
     '''
    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for i in range(len(mlp_list)):
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel + 3
            for out_channel in mlp_list[i]:
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points):
        """
            Input:
                xyz: input points position data, [B, C, N]
                points: input points data, [B, D, N]
            Return:
                new_xyz: sampled points position data, [B, C, S]
                new_points_concat: sample points feature data, [B, D', S]
        """

        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        B, N, C = xyz.shape
        S = self.npoint
        new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
        new_points_list = []
        for i, radius in enumerate(self.radius_list):
            K = self.nsample_list[i]
            group_idx = query_ball_point(radius, K, xyz, new_xyz)
            grouped_xyz = index_points(xyz, group_idx)
            grouped_xyz -= new_xyz.view(B, S, 1, C)
            if points is not None:
                grouped_points = index_points(points, group_idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
            else:
                grouped_points = grouped_xyz

            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]
            for j in range(len(self.conv_blocks[i])):
                conv = self.conv_blocks[i][j]
                bn = self.bn_blocks[i][j]
                grouped_points =  F.relu(bn(conv(grouped_points)))
            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]
            new_points_list.append(new_points)

        new_xyz = new_xyz.permute(0, 2, 1)
        new_points_concat = torch.cat(new_points_list, dim=1)
        return new_xyz, new_points_concat
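
A minimal usage sketch (my own, hypothetical two-scale hyperparameters). With points=None, in_channel=0 works because the +3 in the constructor accounts for the grouped coordinates; the per-scale outputs (64 and 128 channels) are concatenated:

sa_msg = PointNetSetAbstractionMsg(npoint=512,
                                   radius_list=[0.1, 0.2],
                                   nsample_list=[16, 32],
                                   in_channel=0,
                                   mlp_list=[[32, 64], [64, 128]])
xyz = torch.rand(2, 3, 1024)
new_xyz, new_points = sa_msg(xyz, None)
print(new_xyz.shape)                    # torch.Size([2, 3, 512])
print(new_points.shape)                 # torch.Size([2, 192, 512])  (64 + 128)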


# Feature propagation: implemented via distance-weighted interpolation plus a stacked MLP;
# farther points get smaller weights, and each point's neighbor weights are normalized to sum to 1
class PointNetFeaturePropagation(nn.Module):
    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
            Input:
                xyz1: input points position data, [B, C, N]
                xyz2: sampled input points position data, [B, C, S]
                points1: input points data, [B, D, N]
                points2: input points data, [B, D, S]
            Return:
                new_points: upsampled points data, [B, D', N]
        """

        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)

        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            # only one source point: broadcast its feature to all N target points
            interpolated_points = points2.repeat(1, N, 1)
        else:
            dists = square_distance(xyz1, xyz2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # 3 nearest neighbors, [B, N, 3]

            # inverse-distance weighting: w_i = (1/d_i) / sum_j(1/d_j) over the 3 neighbors
            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            # weighted sum of the 3 neighbors' features: [B, N, D]
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        return new_points
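
A minimal usage sketch (my own, hypothetical shapes): propagating 256-channel features from 128 sampled points back to 1024 points whose own features have 64 channels, so in_channel = 256 + 64:

fp = PointNetFeaturePropagation(in_channel=256 + 64, mlp=[256, 128])
xyz1 = torch.rand(2, 3, 1024)           # dense positions, [B, C, N]
xyz2 = torch.rand(2, 3, 128)            # sparse positions, [B, C, S]
points1 = torch.rand(2, 64, 1024)       # dense features, [B, D1, N]
points2 = torch.rand(2, 256, 128)       # sparse features, [B, D2, S]
new_points = fp(xyz1, xyz2, points1, points2)
print(new_points.shape)                 # torch.Size([2, 128, 1024])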

