数据集增强(Data Augmentation)是机器学习常用的数据预处理方法。例如,当手头的数据量太少时,可以人工生成一些有意义的数据用来训练,这种数据获取方法的突出优点是:成本低,效果好。另外,当用来分类的数据集有数据倾斜(skewed data)即某一类样本比另一类多很多时,可以这对样本较少的一类进行数据增强。
在图像领域,常用的数据增强方法有:旋转,镜像,缩放等。
而在激光点云中,常用的数据增强方法有:旋转,加噪声,降采样,不同程度的遮挡等。
这里暂时只考虑旋转和在每个点的坐标XYZ上加高斯噪声等。理论上也可以对回波强度加上噪声,但噪声的方差和均值很难把握,设置不对的话会起到相反的作用,因此这里先不考虑在回波强度上加噪声。事实上,当采用加噪声的方法进行数据增强时,必须仔细选择噪声的[方差]!
具体代码如下:因为我这里主要考虑激光雷达采集到的路面交通对象,所以在旋转时只考虑了绕Z轴旋转。
# -*- coding: utf-8 -*-
#######################################
########## Data Augmentation ##########
#######################################
import numpy as np
###########
# 绕Z轴旋转 #
###########
# point: vector(1*3:x,y,z)
# rotation_angle: scaler 0~2*pi
def rotate_point (point, rotation_angle):
point = np.array(point)
cos_theta = np.cos(rotation_angle)
sin_theta = np.sin(rotation_angle)
rotation_matrix = np.array([[cos_theta, sin_theta, 0],
[-sin_theta, cos_theta, 0],
[0, 0, 1]])
rotated_point = np.dot(point.reshape(-1, 3), rotation_matrix)
return rotated_point
# point = np.array([1,2,3])
# rotated_point = rotate_point(point, 0.1*np.pi)
# print rotated_point
###########
# 在XYZ上加高斯噪声 #
###########
def jitter_point(point, sigma=0.01, clip=0.05):
assert(clip > 0)
point = np.array(point)
point = point.reshape(-1,3)
Row, Col = point.shape
jittered_point = np.clip(sigma * np.random.randn(Row, Col), -1*clip, clip)
jittered_point += point
return jittered_point
# jittered_point = jitter_point(point)
# print jittered_point
###########
# Data Augmentation #
###########
def augment_data(point, rotation_angle, sigma, clip):
return jitter_point(rotate_point(point, rotation_angle), sigma, clip)
point = np.array(point) 这一语句是将point转换为numpy数组,保证输入的List类型也能运行。
point = point.reshape(-1,3) 是将point变为行向量,考虑到输入有可能是列向量。