复现了一下PointNet++,拿自己做的数据集试试感觉效果挺好,于是开始读读代码,打打基础,顺便找点灵感和思路。代码中有些注释是基于我自己制作的数据集进行解释的,我的数据集仿照的是ShapeNet数据集的格式,在我的数据集中,只有一种物体:book(书),book有两个部分:background(背景)和seam(书缝),分别对应0和1。
ShapeNetDataLoader.py
ShapeNetDataLoader.py作用就是把n个点云数据转换成一个数组,数组有n项,每项包含点的信息,点的大类别(book),点的小类别(background,seam)。
# *_*coding:utf-8 *_*
import os
import json
import warnings
import numpy as np
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
def pc_normalize(pc):
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
pc = pc / m
return pc
class PartNormalDataset(Dataset):
def __init__(self,root = './data/book_seam_dataset', npoints=50000, split='train', class_choice=None, normal_channel=False):
self.npoints = npoints
self.root = root
self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
self.cat = {}
self.normal_channel = normal_channel
with open(self.catfile, 'r') as f:
for line in f:
ls = line.strip().split()
self.cat[ls[0]] = ls[1]
self.cat = {k: v for k, v in self.cat.items()} # {'book': '12345678'}
self.classes_original = dict(zip(self.cat, range(len(self.cat)))) # {'book': 0}
if not class_choice is None:
self.cat = {k:v for k,v in self.cat.items() if k in class_choice}
# print(self.cat)
self.meta = {}
with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
train_ids = set([str(d.split('/')[2]) for d in json.load(f)]) # {'1', '2', ...}
with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
for item in self.cat: # item:'book'
# print('category', item)
self.meta[item] = []
dir_point = os.path.join(self.root, self.cat[item])
fns = sorted(os.listdir(dir_point))
# print(fns[0][0:-4])
if split == 'trainval': # 取训练集+验证集
fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))] # fn[0:-4]就是‘1.txt’里面的‘1’, fns:['1.txt', '10.txt', ...]
elif split == 'train':
fns = [fn for fn in fns if fn[0:-4] in train_ids]
elif split == 'val':
fns = [fn for fn in fns if fn[0:-4] in val_ids]
elif split == 'test':
fns = [fn for fn in fns if fn[0:-4] in test_ids]
else:
print('Unknown split: %s. Exiting..' % (split))
exit(-1)
# print(os.path.basename(fns))
for fn in fns:
token = (os.path.splitext(os.path.basename(fn))[0]) # os.path.basename删除目录名,保留文件名, token:'1'
self.meta[item].append(os.path.join(dir_point, token + '.txt')) # {'book': ['data/book_seam_datas...5678/1.txt','...',...}
self.datapath = []
for item in self.cat:
for fn in self.meta[item]:
self.datapath.append((item, fn))
self.classes = {}
for i in self.cat.keys():
self.classes[i] = self.classes_original[i]
# Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels
self.seg_classes = {'book': [0, 1]}
# for cat in sorted(self.seg_classes.keys()):
# print(cat, self.seg_classes[cat])
self.cache = {} # from index to (point_set, cls, seg) tuple
self.cache_size = 20000 # 缓存点的数据,采样点最多不能超过缓存点数量最大值(20000)
def __getitem__(self, index):
if index in self.cache:
point_set, cls, seg = self.cache[index]
else:
fn = self.datapath[index] # ('book', 'data/book_seam_datas...5678/5.txt')
cat = self.datapath[index][0] # 'book'
cls = self.classes[cat] # [0]
cls = np.array([cls]).astype(np.int32)
data = np.loadtxt(fn[1]).astype(np.float32)
if not self.normal_channel:
point_set = data[:, 0:3]
else:
point_set = data[:, 0:6]
seg = data[:, -1].astype(np.int32)
if len(self.cache) < self.cache_size:
self.cache[index] = (point_set, cls, seg)
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
choice = np.random.choice(len(seg), self.npoints, replace=True)
# resample
point_set = point_set[choice, :]
seg = seg[choice]
return point_set, cls, seg # 点的信息,点的大类别(book),点的小类别(background,seam)
def __len__(self):
return len(self.datapath)
pointnet2_part_seg_msg.py
pointnet2_part_seg_msg.py是整个网络的整体框架,通过调用pointnet_utils.py中的自定义网络模型进行一层一层的搭建,因此看不到网络的具体细节。
pointnet2_utils.py
pointnet2_utils.py中定义了各类网络模型以及pointnet++的关键算法。
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np
def timeit(tag, t):
print("{}: {}s".format(tag, time() - t))
return time()
def pc_normalize(pc):
l = pc.shape[0]
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m
return pc
def square_distance(src, dst):
"""
Calculate Euclid distance between each two points.
src^T * dst = xn * xm + yn * ym + zn * zm;
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
B:batchsize, N:第一组点个数, M:第二组点个数, C:输入点通道数(xyz.C=3)
Input:
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
dist: per-point square distance, [B, N, M]
batchsize个[N,M]
"""
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) # permute:转换维度
dist += torch.sum(src ** 2, -1).view(B, N, 1) # view:按维度填充
dist += torch.sum(dst ** 2, -1).view(B, 1, M) # 数组广播机制,右边的式子复制N组后与dist叠加
return dist
def index_points(points, idx): # i按照输入的点云数据和索引返回由索引的点云数据。
"""
Input:
points: input points data, [B, N, C]
idx: sample index data, [B, S]
Return:
new_points:, indexed points data, [B, S, C]
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape) #view_shape=[B,S]
view_shape[1:] = [1] * (len(view_shape) - 1) #[1] * (len(view_shape) - 1) -> [1],即view_shape=[B,1]
repeat_shape = list(idx.shape) #repeat_shape=[B,S]
repeat_shape[0] = 1 #repeat_shape=[1,S]
#.view(view_shape)=.view(B,1)
#.repeat(repeat_shape)=.view(1,S)
#batch_indices的维度[B,S]
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
new_points = points[batch_indices, idx, :]
return new_points
def farthest_point_sample(xyz, npoint):
'''
FPS的逻辑如下:
假设一共有n个点,整个点集为N = {f1, f2,…,fn}, 目标是选取n1个起始点做为下一步的中心点:
随机选取一个点fi为起始点,并写入起始点集 B = {fi};
选取剩余n-1个点计算和fi点的距离,选择最远点fj写入起始点集B={fi,fj};
选取剩余n-2个点计算和点集B中每个点的距离, 将最短的那个距离作为该点到点集的距离, 这样得到n-2个到点集的距离,选取最远的那个点写入起始点B = {fi, fj ,fk},同时剩下n-3个点, 如果n1=3 则到此选择完毕;
如果n1 > 3则重复上面步骤直到选取n1个起始点为止.
'''
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
device = xyz.device
B, N, C = xyz.shape # B:BatchSize, N:ndataset(点云中点的个数), C:dimension
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) # 提取得到中心点的集合
distance = torch.ones(B, N).to(device) * 1e10 # 记录某个样本中所有点到某一个点的距离,先取很大
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) # 当前最远的点,随机初始化,范围为0~N,初始化B个,对应到每个样本都随机有一个初始最远点,B列的行向量
batch_indices = torch.arange(B, dtype=torch.long).to(device) # batch的索引,0~(B-1)的数组
for i in range(npoint):
centroids[:, i] = farthest # 第i个最远点
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) # 取出最远点xyz坐标
dist = torch.sum((xyz - centroid) ** 2, -1) # 计算距离,-1代表行求和
mask = dist < distance # 一个bool值的张量数组
distance[mask] = dist[mask] # True的会留下,False删除
farthest = torch.max(distance, -1)[1] # 返回一个张量,第一项是最大值,第二项是索引,-1代表列索引
return centroids
def query_ball_point(radius, nsample, xyz, new_xyz):
'''
'''
"""
Input:
radius: local region radius # radius为半径,new_xyz为中心,取nsample个点
nsample: max sample number in local region
xyz: all points, [B, N, 3] # 所有点
new_xyz: query points, [B, S, 3] # farthest_point_sample得到S个中心点, new_xyz为中心点xyz
Return:
group_idx: grouped points index, [B, S, nsample] # nsameple个点的索引
"""
device = xyz.device
B, N, C = xyz.shape
_, S, _ = new_xyz.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) # torch.arange得到索引,view转换为三维,repeat使其复制成[B,S,N]
sqrdists = square_distance(new_xyz, xyz) # 计算中心点与所有点之间的欧几里德距离
group_idx[sqrdists > radius ** 2] = N # 大于半径的点设置成N
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] # 做升序排列,前面大于radius^2的都是N,会是最大值,所以会直接在剩下的点中取出前nsample个点. 0代表输出值,1代表索引
# 考虑到有可能前nsample个点中也有被赋值为N的点(即球形区域内不足nsample个点),这种点需要舍弃,直接用第一个点来代替即可
# group_first: [B, S, nsample], 实际就是把group_idx中的第一个点的值复制到[B, S, nsample]的维度,便利于后面的替换
# 这里要用view是因为group_idx[:, :, 0]取出之后的tensor相当于二维Tensor,因此需要用view变成三维tensor
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
# 找到group_idx中值等于N的点,会输出0,1构成的三维Tensor,维度为[B,S,nsample]
mask = group_idx == N
# 将这些点的值替换为第一个点的值
group_idx[mask] = group_first[mask]
return group_idx
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
"""
Input:
npoint:
radius:
nsample:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, npoint, nsample, 3]
new_points: sampled points data, [B, npoint, nsample, 3+D]
"""
B, N, C = xyz.shape
S = npoint
fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 中心点
new_xyz = index_points(xyz, fps_idx) # 中心点位置
idx = query_ball_point(radius, nsample, xyz, new_xyz) # 球查询得到点的索引
grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] # 球查询点的位置
grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) # 计算与中心点距离
if points is not None:
grouped_points = index_points(points, idx)
new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] C=3,D为点的特征维度(位置、法向、颜色)
else:
new_points = grouped_xyz_norm
if returnfps:
return new_xyz, new_points, grouped_xyz, fps_idx
else:
return new_xyz, new_points
def sample_and_group_all(xyz, points):
"""
Input:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, 1, 3]
new_points: sampled points data, [B, 1, N, 3+D]
"""
device = xyz.device
B, N, C = xyz.shape
new_xyz = torch.zeros(B, 1, C).to(device)
grouped_xyz = xyz.view(B, 1, N, C) #new_xyz代表中心点,用原点表示
if points is not None:
new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
else:
new_points = grouped_xyz
return new_xyz, new_points
class PointNetSetAbstraction(nn.Module):
def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
super(PointNetSetAbstraction, self).__init__()
self.npoint = npoint
self.radius = radius
self.nsample = nsample
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) # MLP就相当于是1x1卷积
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.group_all = group_all
def forward(self, xyz, points):
"""
N是输入点的数量,C是坐标维度(C=3),D是特征维度(除坐标维度以外的其他特征维度)
S是输出点的数量,C是坐标维度,D'是新的特征维度
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
xyz = xyz.permute(0, 2, 1) # [B, N, 3]
if points is not None:
points = points.permute(0, 2, 1)
if self.group_all:
new_xyz, new_points = sample_and_group_all(xyz, points)
else:
new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
# new_xyz: sampled points position data, [B, npoint, C]
# new_points: sampled points data, [B, npoint, nsample, C+D]
new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] # pytorch的通道顺序是NCHW
# N - Batch
# C - Channel
# H - Height
# W - Width
# 对[3+D, nsample]的维度上做逐像素的卷积,结果相当于对单个C+D维度做1d的卷积
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
new_points = torch.max(new_points, 2)[0]
new_xyz = new_xyz.permute(0, 2, 1)
return new_xyz, new_points
class PointNetSetAbstractionMsg(nn.Module):
def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
super(PointNetSetAbstractionMsg, self).__init__()
self.npoint = npoint
self.radius_list = radius_list
self.nsample_list = nsample_list
self.conv_blocks = nn.ModuleList()
self.bn_blocks = nn.ModuleList()
for i in range(len(mlp_list)):
convs = nn.ModuleList()
bns = nn.ModuleList()
last_channel = in_channel + 3
for out_channel in mlp_list[i]:
convs.append(nn.Conv2d(last_channel, out_channel, 1))
bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.conv_blocks.append(convs)
self.bn_blocks.append(bns)
def forward(self, xyz, points):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
xyz = xyz.permute(0, 2, 1)
if points is not None:
points = points.permute(0, 2, 1)
B, N, C = xyz.shape
S = self.npoint
new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
new_points_list = []
# 针对多个radius和nsample取点
for i, radius in enumerate(self.radius_list):
K = self.nsample_list[i]
group_idx = query_ball_point(radius, K, xyz, new_xyz)
grouped_xyz = index_points(xyz, group_idx)
grouped_xyz -= new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, group_idx)
grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
else:
grouped_points = grouped_xyz
grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
for j in range(len(self.conv_blocks[i])):
conv = self.conv_blocks[i][j]
bn = self.bn_blocks[i][j]
grouped_points = F.relu(bn(conv(grouped_points)))
new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
new_points_list.append(new_points)
new_xyz = new_xyz.permute(0, 2, 1)
new_points_concat = torch.cat(new_points_list, dim=1)
return new_xyz, new_points_concat
class PointNetFeaturePropagation(nn.Module):
def __init__(self, in_channel, mlp):
super(PointNetFeaturePropagation, self).__init__()
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm1d(out_channel))
last_channel = out_channel
def forward(self, xyz1, xyz2, points1, points2):
"""
Input:
xyz1: input points position data, [B, C, N] # 所有点
xyz2: sampled input points position data, [B, C, S] # 采样点
points1: input points data, [B, D, N]
points2: input points data, [B, D, S]
Return:
new_points: upsampled points data, [B, D', N]
"""
" 将B C N 转换为B N C 然后利用插值将高维点云数目S 插值到低维点云数目N (N大于S)"
" xyz1 低维点云 数量为N xyz2 高维点云 数量为S"
xyz1 = xyz1.permute(0, 2, 1)
xyz2 = xyz2.permute(0, 2, 1)
points2 = points2.permute(0, 2, 1)
B, N, C = xyz1.shape
_, S, _ = xyz2.shape
"如果最后只有一个点,就将S直复制N份后与与低维信息进行拼接"
if S == 1:
interpolated_points = points2.repeat(1, N, 1)
else:
dists = square_distance(xyz1, xyz2) # [B,N,S]
dists, idx = dists.sort(dim=-1) # 找到距离最近的三个邻居
dists, idx = dists[:, :, :3], idx[:, :, :3] # [B,N,3],N个点与这S个距离最近的前三个点的索引
dist_recip = 1.0 / (dists + 1e-8) # 求距离的倒数 2,512,3 对应论文中的 Wi(x)
norm = torch.sum(dist_recip, dim=2, keepdim=True) # 也就是将距离最近的三个邻居的加起来 此时对应论文中公式的分母部分
weight = dist_recip / norm
"""
这里的weight是计算权重 dist_recip中存放的是三个邻居的距离 norm中存放是距离的和
两者相除就是每个距离占总和的比重 也就是weight
"""
interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) # 点乘
if points1 is not None:
points1 = points1.permute(0, 2, 1)
new_points = torch.cat([points1, interpolated_points], dim=-1)
else:
new_points = interpolated_points
new_points = new_points.permute(0, 2, 1)
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
return new_points
总的来讲,这个文件主要实现的是两个网络结构:PointNetSetAbstraction和PointNetFeaturePropagation,
PointNetSetAbstractionMsg只是PointNetSetAbstraction使用多个采样半径后叠加的结果。
文中的注释参考了博主weixin_42707080的PointNet++系列文章和正在学习的浅语的文章《PointNet++上采样(Feature Propagation)》。