点云中的数据增强方法
本文在相机坐标系下进行演示（KITTI 的标签定义在相机坐标系中）。
一、旋转（在相机坐标系中绕 y 轴旋转）
import numpy as np
def rotation_points_single_angle(points, angle, axis=0):
    """Rotate every row of ``points`` by ``angle`` (radians) about one axis.

    Rows are treated as row vectors and multiplied on the right by a 3x3
    matrix built with ``points.dtype``, which yields:

      * axis == 0:        (x, y, z) -> (x, y*cos + z*sin, -y*sin + z*cos)
      * axis == 1:        (x, y, z) -> (x*cos + z*sin, y, -x*sin + z*cos)
      * axis == 2 or -1:  (x, y, z) -> (x*cos + y*sin, -x*sin + y*cos, z)

    Args:
        points: array of shape (N, 3).
        angle: rotation angle in radians (passed to ``np.sin``/``np.cos``).
        axis: 0, 1, 2 or -1 (-1 is treated the same as 2).

    Returns:
        A new (N, 3) array with the rotated coordinates.

    Raises:
        ValueError: if ``axis`` is not one of the supported values.
    """
    sin_a = np.sin(angle)
    cos_a = np.cos(angle)

    if axis == 0:
        rows = [[1, 0, 0], [0, cos_a, -sin_a], [0, sin_a, cos_a]]
    elif axis == 1:
        rows = [[cos_a, 0, -sin_a], [0, 1, 0], [sin_a, 0, cos_a]]
    elif axis in (2, -1):
        rows = [[cos_a, -sin_a, 0], [sin_a, cos_a, 0], [0, 0, 1]]
    else:
        raise ValueError("axis should in range")

    # Transposed rotation matrix, so it can right-multiply row vectors.
    transposed = np.array(rows, dtype=points.dtype)
    return points @ transposed
def global_rotation(gt_boxes, points, rotation_range=(-np.pi / 2, np.pi / 2)):
    """Randomly rotate the whole scene (points and boxes) about the y axis.

    One angle is drawn uniformly from ``rotation_range`` and applied to the
    point cloud and to the box locations via
    ``rotation_points_single_angle(..., axis=1)``. The same angle is added to
    the box yaw column so the boxes stay aligned with the rotated points. Both
    arrays are modified in place and also returned.

    Args:
        gt_boxes: array of shape (M, >=7); columns 0-2 are the box location and
            column 6 the yaw (presumably KITTI ``rotation_y`` in the camera
            frame, in radians; confirm against the caller).
        points: array of shape (N, >=3); columns 0-2 are x, y, z in the same
            frame. Any extra columns are left untouched.
        rotation_range: ``(low, high)`` bounds of the uniform rotation noise,
            in radians. Defaults to +/-90 degrees.

    Returns:
        The ``(gt_boxes, points)`` pair after rotation.
    """
    # np.sin / np.cos (inside rotation_points_single_angle) take radians, so
    # the noise must be sampled in radians, not degrees.
    low, high = rotation_range
    noise_rotation = np.random.uniform(low, high)
    points[:, :3] = rotation_points_single_angle(
        points[:, :3], noise_rotation, axis=1)
    gt_boxes[:, :3] = rotation_points_single_angle(
        gt_boxes[:, :3], noise_rotation, axis=1)
    gt_boxes[:, 6] += noise_rotation
    return gt_boxes, points
输出:
二、镜像
分别对 x 坐标和 z 坐标取反（x → -x、z → -z）进行镜像翻转。
import numpy as np
def random_flip(gt_boxes, points, probability=0.5):
    """Randomly mirror the scene along x and/or z.

    The x flip (x -> -x) and the z flip (z -> -z) are each applied
    independently with ``probability``. Points and box locations get the
    corresponding coordinate negated, and the box yaw is updated to match.
    Arrays are modified in place and also returned.

    The yaw update assumes the KITTI camera-frame convention, in which a box
    with yaw ``ry`` points along ``(cos ry, 0, -sin ry)``:

      * x -> -x gives heading ``(-cos ry, 0, -sin ry)``, i.e. ``pi - ry``
      * z -> -z gives heading ``(cos ry, 0, sin ry)``, i.e. ``-ry``

    Applying both flips is a 180-degree turn about y, giving ``ry + pi``.

    Args:
        gt_boxes: array of shape (M, >=7); column 0 is x, column 2 is z and
            column 6 is the yaw.
        points: array of shape (N, >=3); column 0 is x and column 2 is z.
        probability: chance of applying each flip, in [0, 1].

    Returns:
        The ``(gt_boxes, points)`` pair after flipping.
    """
    flip_x = np.random.choice(
        [False, True], replace=False, p=[1 - probability, probability])
    flip_z = np.random.choice(
        [False, True], replace=False, p=[1 - probability, probability])
    if flip_z:
        gt_boxes[:, 2] = -gt_boxes[:, 2]
        # Only the heading's z component changes sign, so ry -> -ry.
        # (pi - ry would reverse the heading direction.)
        gt_boxes[:, 6] = -gt_boxes[:, 6]
        points[:, 2] = -points[:, 2]
    if flip_x:
        gt_boxes[:, 0] = -gt_boxes[:, 0]
        # Only the heading's x component changes sign, so ry -> pi - ry.
        gt_boxes[:, 6] = -gt_boxes[:, 6] + np.pi
        points[:, 0] = -points[:, 0]
    return gt_boxes, points
x轴镜像:
三、提取真值框内的点云
def extract_points_in_box(points, corners):
    """Return the points lying strictly inside the axis-aligned bounds of a box.

    Args:
        points: array of shape (N, >=3); columns 0-2 are x, y, z in the same
            frame as ``corners``. Extra columns are kept in the output.
        corners: array of shape (8, >=3) holding the box's corner coordinates.

    Returns:
        The rows of ``points`` whose x, y and z are each strictly between the
        min and max of the corresponding corner coordinate.

    Note:
        Only the axis-aligned envelope of the 8 corners is used. For a box
        with non-zero yaw this also keeps points that are inside the envelope
        but outside the rotated box itself.
    """
    points = np.asarray(points)
    corner_xyz = np.asarray(corners)[:, :3]
    lower = corner_xyz.min(axis=0)
    upper = corner_xyz.max(axis=0)
    xyz = points[:, :3]
    # Strict inequalities: points exactly on a face are excluded.
    inside = np.all((xyz > lower) & (xyz < upper), axis=1)
    return points[inside]


if __name__ == '__main__':
    import time

    # The original write-up leaves ``pt`` (camera-frame point cloud, (N, >=3))
    # and ``dets`` (label-box corners, dets[0] -> the 8 corners of one box) as
    # bare placeholders, which raise NameError. These synthetic stand-ins keep
    # the demo runnable; replace them with real data.
    pt = np.random.uniform(-10.0, 10.0, size=(100000, 4))
    dets = np.array([[[x, y, z]
                      for x in (-1.0, 1.0)
                      for y in (-1.0, 1.0)
                      for z in (-1.0, 1.0)]])

    box_corners = dets[0]
    lower = box_corners[:, :3].min(axis=0)
    upper = box_corners[:, :3].max(axis=0)
    print(lower[0], upper[0], lower[1], upper[1], lower[2], upper[2])

    t = time.perf_counter()
    nn = extract_points_in_box(pt, box_corners)
    print(time.perf_counter() - t)
    # show(pt)
提取耗时（单位：秒，约为其他提取方法的百分之一）：
0.0018105506896972656