17.5 准备数据文件
在“data”目录中实现了多个程序文件,主要用于处理、收集和创建用于机器学习和模型训练的数据。这些程序文件旨在支持CARLA仿真环境中数据的采集、整理、转换和存储,以便进行自动驾驶相关任务的数据预处理和模型训练。这些功能对于开发和测试自动驾驶算法和模型非常有用。
17.5.1 CARLA数据处理与转换
编写文件data/utils.py,功能是提供了处理CARLA模拟器中的图像和数据的方法,例如将图像从CARLA原始格式转换为numpy数组,执行语义分割标签的映射,计算车辆距离车道中心的距离等。这些函数用于处理和转换与自动驾驶和模拟有关的数据。具体实现代码如下所示。
import math
import numpy
def to_bgra_array(image):
    """Convert a CARLA raw image into a BGRA numpy array.

    Args:
        image: Object exposing ``raw_data`` (a bytes-like buffer) together
            with ``height`` and ``width`` attributes.

    Returns:
        A uint8 array of shape ``(height, width, 4)`` built on ``raw_data``.
    """
    flat = numpy.frombuffer(image.raw_data, dtype=numpy.uint8)
    return flat.reshape((image.height, image.width, 4))
def to_rgb_array(image):
    """Convert a CARLA raw image into an RGB numpy array.

    Args:
        image: Raw image accepted by ``to_bgra_array``.

    Returns:
        A ``(height, width, 3)`` view of the BGRA array with the channel
        axis reordered to R, G, B.
    """
    bgra = to_bgra_array(image)
    # One slice does both steps: start at index 2 and walk backwards, which
    # drops the alpha channel and turns B, G, R into R, G, B.
    return bgra[:, :, 2::-1]
def labels_to_array(image):
    """Convert a CARLA semantic segmentation image into a 2D label array.

    Args:
        image: Raw image accepted by ``to_bgra_array``.

    Returns:
        A ``(height, width)`` array holding one label value per pixel.
    """
    bgra = to_bgra_array(image)
    # Index 2 of the BGRA layout (the red channel) is returned as the label.
    return bgra[:, :, 2]
def labels_to_cityscapes_palette(image):
    """Convert a CARLA semantic segmentation image to the Cityscapes palette.

    Args:
        image: Raw image accepted by ``labels_to_array``.

    Returns:
        A float array of shape ``(height, width, 3)`` holding one color per
        pixel. Labels that are not in the palette stay ``[0, 0, 0]``.
    """
    palette = {
        0: [0, 0, 0],         # none
        1: [70, 70, 70],      # buildings
        2: [190, 153, 153],   # fences
        3: [72, 0, 90],       # other
        4: [220, 20, 60],     # pedestrians
        5: [153, 153, 153],   # poles
        6: [157, 234, 50],    # road lines
        7: [128, 64, 128],    # roads
        8: [244, 35, 232],    # sidewalks
        9: [107, 142, 35],    # vegetation
        10: [0, 0, 255],      # vehicles
        11: [102, 102, 156],  # walls
        12: [220, 220, 0],    # traffic signs
    }
    labels = labels_to_array(image)
    height, width = labels.shape[0], labels.shape[1]
    colored = numpy.zeros((height, width, 3))
    for label, color in palette.items():
        # A boolean mask selects every pixel carrying this label.
        colored[labels == label] = color
    return colored
def depth_to_array(image):
    """Convert a CARLA encoded depth image into a normalized 2D depth array.

    Args:
        image: Raw image accepted by ``to_bgra_array``.

    Returns:
        A float32 array of shape ``(height, width)`` with one depth value
        per pixel, normalized to [0.0, 1.0].
    """
    bgra = to_bgra_array(image).astype(numpy.float32)
    # Channels are laid out B, G, R, so these weights evaluate
    # (R + G * 256 + B * 256 * 256) / (256 * 256 * 256 - 1).
    channel_weights = [65536.0, 256.0, 1.0]
    depth = numpy.dot(bgra[:, :, :3], channel_weights)
    depth /= 16777215.0  # 256 * 256 * 256 - 1
    return depth.astype(numpy.float32)
def depth_to_logarithmic_grayscale(normalized_depth):
    """Convert a normalized depth map into a logarithmic grayscale image.

    Args:
        normalized_depth: numpy array of depth values normalized to
            [0.0, 1.0], e.g. the output of ``depth_to_array``.

    Returns:
        A uint8 array with the same shape as ``normalized_depth``. Each value
        is ``1 + ln(depth) / 5.70378`` clipped to [0, 1] and scaled to
        [0, 255]; a depth of exactly 0 maps to 0.
    """
    # log(0) is -inf, which the clip below turns into 0. Silence numpy's
    # divide-by-zero warning so zero-depth pixels do not flood the log.
    with numpy.errstate(divide="ignore"):
        logdepth = numpy.ones(normalized_depth.shape) + \
            (numpy.log(normalized_depth) / 5.70378)
    logdepth = numpy.clip(logdepth, 0.0, 1.0)
    logdepth *= 255.0
    # The result stays single-channel; no color channels are added.
    return logdepth.astype(numpy.uint8)
def distance_from_center(previous_wp, current_wp, car_loc):
    """Return the distance from the vehicle to the line through two waypoints.

    Args:
        previous_wp: Previous waypoint (read via ``transform.location.x/y``).
        current_wp: Current waypoint (read via ``transform.location.x/y``).
        car_loc: Vehicle location (read via ``x`` and ``y``).

    Returns:
        The perpendicular distance in the x/y plane. A small epsilon in the
        denominator keeps the result finite when both waypoints coincide.
    """
    start = previous_wp.transform.location
    end = current_wp.transform.location
    delta_x = end.x - start.x
    delta_y = end.y - start.y

    # Implicit line equation a*x + b*y + c = 0 through both waypoints.
    a = delta_y
    b = -delta_x
    c = delta_x * start.y - delta_y * start.x

    numerator = abs(a * car_loc.x + b * car_loc.y + c)
    return numerator / (a ** 2 + b ** 2 + 1e-6) ** 0.5
def low_resolution_semantics(image):
    """Collapse CARLA semantic labels (29 classes) into 14 classes, in place.

    Warning: ``image`` is overwritten; nothing is returned.

    Args:
        image: Array-like of label ids that supports boolean-mask assignment
            (``image[image == i] = value``).
    """
    mapping = {
        0: 0,    # none
        1: 1,    # road
        2: 2,    # sidewalk
        3: 3,    # building
        4: 4,    # wall
        5: 5,    # fence
        6: 6,    # pole
        7: 7,    # traffic light
        8: 7,    # traffic sign
        9: 8,    # vegetation
        10: 9,   # terrain -> other
        11: 10,  # sky (added)
        12: 11,  # pedestrian
        13: 11,  # rider -> pedestrian
        14: 12,  # car -> vehicle
        15: 12,  # truck -> vehicle
        16: 12,  # bus -> vehicle
        17: 12,  # train -> vehicle
        18: 12,  # motorcycle -> vehicle
        19: 12,  # bicycle -> vehicle
        20: 9,   # static -> other
        21: 9,   # dynamic -> other
        22: 9,   # other -> other
        23: 9,   # water -> other
        24: 13,  # road line
        25: 9,   # ground -> other
        26: 9,   # bridge -> other
        27: 9,   # rail track -> other
        28: 9,   # guard rail -> other
    }
    # Labels are rewritten one at a time in ascending source order. Every
    # target id is <= its source id, so a pixel that was already remapped is
    # never picked up again by a later step.
    for source, target in mapping.items():
        if source < 8:
            continue  # labels 0-7 keep their ids
        image[image == source] = target
def dist_to_roadline(carla_map, vehicle):
    """Estimate the vehicle's distance to the left and right lane lines.

    For every bounding-box corner whose ``z`` is below 1, the distance to the
    waypoint returned by ``carla_map.get_waypoint(corner)`` is measured and
    combined with half of the lane width of the waypoint at the vehicle
    location. The smallest value found per side is kept (starting from 100).

    Args:
        carla_map: Object providing ``get_waypoint(location)``.
        vehicle: Object providing ``get_transform()`` and ``bounding_box``.

    Returns:
        Tuple ``(dis_to_left, dis_to_right, sin(yaw_diff), cos(yaw_diff))``
        where ``yaw_diff`` is the vehicle yaw minus the waypoint yaw.
    """
    vehicle_location = vehicle.get_transform().location
    vehicle_yaw = vehicle.get_transform().rotation.yaw
    waypoint = carla_map.get_waypoint(vehicle_location)
    yaw_diff = vehicle_yaw - waypoint.transform.rotation.yaw
    yaw_diff_rad = yaw_diff / 180 * math.pi

    corners = vehicle.bounding_box.get_world_vertices(vehicle.get_transform())
    dis_to_left, dis_to_right = 100, 100
    for corner in corners:
        if not corner.z < 1:
            continue  # only corners below z = 1 are considered

        corner_wp = carla_map.get_waypoint(corner).transform
        vec_x = corner_wp.location.x - corner.x
        vec_y = corner_wp.location.y - corner.y
        dis_to_waypt = math.sqrt(vec_x ** 2 + vec_y ** 2)
        vec_angle = math.atan2(vec_y, vec_x) * 180 / math.pi
        angle_diff = vec_angle - corner_wp.rotation.yaw

        # Lane width comes from the waypoint at the vehicle location.
        near = waypoint.lane_width / 2 - dis_to_waypt
        far = waypoint.lane_width / 2 + dis_to_waypt
        # Both ranges describe the same half-plane; the second one covers
        # angle differences that were not wrapped back into (-180, 180].
        in_first_half = (0 < angle_diff < 180) or (-360 < angle_diff < -180)
        if in_first_half:
            dis_to_left = min(dis_to_left, near)
            dis_to_right = min(dis_to_right, far)
        else:
            dis_to_left = min(dis_to_left, far)
            dis_to_right = min(dis_to_right, near)

    return dis_to_left, dis_to_right, math.sin(yaw_diff_rad), math.cos(yaw_diff_rad)
对上述代码的具体说明如下:
- to_bgra_array(image):将CARLA原始图像转换为BGRA numpy数组。
- to_rgb_array(image):将CARLA原始图像转换为RGB numpy数组。
- labels_to_array(image):将包含CARLA语义分割标签的图像转换为包含每个像素的标签的2D数组。
- labels_to_cityscapes_palette(image):将包含CARLA语义分割标签的图像转换为Cityscapes调色板。
- depth_to_array(image):将包含CARLA编码深度图的图像转换为每个像素深度值归一化在[0.0, 1.0]之间的2D数组。
- depth_to_logarithmic_grayscale(normalized_depth):将包含CARLA编码深度图的图像转换为对数灰度图像数组。
- distance_from_center(previous_wp, current_wp, car_loc):计算车辆距离车道中心的距离。
- low_resolution_semantics(image):将CARLA语义图像(29个类别)转换为低分辨率语义分割图像(14个类别)。
- dist_to_roadline(carla_map, vehicle):计算车辆距离车道线的距离。
上述函数用于处理CARLA模拟器中的图像数据、语义分割数据、深度数据以及计算车辆与道路中心或道路线之间的距离等操作。它们提供了用于数据处理和转换的工具。
17.5.2 加载、处理数据
编写文件data/dataset.py,定义了一个名为 AutoencoderDataset 的类,这是一个自定义的PyTorch数据集,用于加载和处理与自动编码器训练相关的数据。具体实现代码如下所示。
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision import transforms
from .utils import low_resolution_semantics
class AutoencoderDataset(Dataset):
    """In-memory dataset of image / semantic pairs for autoencoder training.

    All samples are read and preprocessed once in ``__init__`` and kept in
    lists; ``__getitem__`` only converts dtypes and applies normalization.
    """

    def __init__(self, data, resize=None, normalize=True, low_sem=True,
                 use_img_as_output=False, normalize_output=False):
        """Load every sample described by ``data``.

        Args:
            data: Iterable of 6-tuples ``(index, rgb, depth_image,
                semantic_image, additional, junction)``. The first element is
                ignored; the three image entries are passed to ``read_image``.
            resize: Size forwarded to ``transforms.Resize``; a falsy value
                (default ``None``) disables resizing.
            normalize: Divide the returned input image by 255.
            low_sem: Reduce the semantic labels with
                ``low_resolution_semantics`` while loading.
            use_img_as_output: Return a copy of the input image as the target
                instead of the semantic map.
            normalize_output: Normalize the target (by 255 for images, by 13
                or 28 for semantic maps depending on ``low_sem``).
        """
        self.images = []
        self.semantics = []
        self.data = []
        self.junctions = []

        # Build the resize transforms once instead of once per sample.
        # Semantic maps use nearest-neighbour interpolation.
        resize_image = transforms.Resize(resize) if resize else None
        resize_semantic = (
            transforms.Resize(resize, transforms.InterpolationMode.NEAREST)
            if resize else None
        )

        for i, (_, rgb, depth_image, semantic_image, additional, junction) in enumerate(data):
            if i % 200 == 0:
                print("Loading data: ", i)

            # RGB and depth images are stacked along the channel dimension.
            image = torch.cat((read_image(rgb), read_image(depth_image)), dim=0)
            semantic = read_image(semantic_image)

            if resize:
                image = resize_image(image)
                semantic = resize_semantic(semantic)

            if low_sem:
                low_resolution_semantics(semantic)  # modifies the tensor in place
            semantic = semantic.squeeze()

            self.images.append(image)
            self.semantics.append(semantic)
            # A distinct local name is used so the ``data`` argument being
            # iterated is not shadowed.
            self.data.append(torch.FloatTensor(additional))
            self.junctions.append(junction)

        self.normalize = normalize
        self.use_img_as_output = use_img_as_output
        self.normalize_output = normalize_output
        self.low_sem = low_sem

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, output, data, junction)`` for sample ``idx``.

        ``image`` is float32 (divided by 255 when ``normalize`` is set).
        ``output`` is either a copy of the un-normalized image or the
        semantic map, optionally normalized. ``data`` is the stored
        additional-values tensor and ``junction`` a one-element float tensor.
        """
        image = self.images[idx].to(torch.float32)

        if self.use_img_as_output:
            # Cloned before the input is normalized, so the two
            # normalization flags stay independent.
            output = torch.clone(image)
            if self.normalize_output:
                output /= 255.0
        else:
            output = self.semantics[idx].to(torch.long)
            if self.normalize_output:
                output = output.to(torch.float32)
                output = output / 13. if self.low_sem else output / 28.

        data = self.data[idx]
        junction = torch.FloatTensor([self.junctions[idx]])

        if self.normalize:
            # Out-of-place division: ``.to`` returns the stored tensor itself
            # when it is already float32, and an in-place ``/=`` would then
            # rescale the cached sample on every access.
            image = image / 255.0
        return image, output, data, junction
对上述代码的具体说明如下:
(1)函数__init__():用于初始化数据集对象,各个参数的具体说明如下:
- data:包含数据的列表,每个数据包括RGB图像、深度图像、语义分割图像、附加数据和交叉口信息。
- resize:一个整数,用于指定是否调整图像大小。
- normalize:一个布尔值,指示是否对图像进行归一化。
- low_sem:一个布尔值,指示是否调用low_resolution_semantics将语义分割图像的29个类别合并为14个类别(即代码中所说的“低分辨率”语义)。
- use_img_as_output:一个布尔值,指示是否将图像用作输出,而不是语义分割图像。
- normalize_output:一个布尔值,指示是否对输出进行归一化。
(2)函数__len__(self):用于返回数据集中的样本数量。
(3)函数__getitem__(self, idx):根据给定的索引 idx 返回数据集中的一个样本。样本包括图像、输出、附加数据和交叉口信息。根据参数设置,它可能会对图像和输出进行归一化和处理。
由此可见,类AutoencoderDataset主要用于加载和处理用于自动编码器训练的数据集,包括RGB图像、深度图像、语义分割图像以及与之相关的其他信息。
17.5.3 收集数据
编写文件data/collect_data.py,这是一个用于收集自动驾驶训练数据的脚本。该脚本通过在CARLA模拟器中生成车辆和传感器来模拟真实道路情况,然后捕获传感器数据(如RGB图像、深度图像、语义分割图像、障碍物距离等)并保存这些数据。具体实现流程如下:
(1)在类Observation中创建 __init__ 函数,功能是创建一个 Observation 对象,用于保存观测数据。然后分别初始化RGB、深度、语义分割图像以及障碍物距离等信息。具体实现代码如下所示。
class Observation:
    """Holds the most recent sensor outputs and where to save them."""

    def __init__(self, save_path, save_np=False):
        """Create an empty observation.

        Args:
            save_path: Root folder that saved frames are written under.
            save_np: Save frames as numpy arrays instead of PNG images.
        """
        # Output configuration.
        self.save_path = save_path
        self.save_np = save_np
        # Latest frames; they stay None until something assigns them.
        self.rgb = self.depth = self.semantic = None
        # (frame detected, distance)
        self.obstacle_dist = (0, 25)
(2)定义函数save_data,用于保存收集到的传感器数据,包括RGB图像、深度图像、语义分割图像,并根据保存方式(图片或Numpy数组)将数据保存到指定路径。具体实现代码如下所示。
def save_data(self, frame):
    """Write the current RGB, depth and semantic frames to disk.

    Files are named after ``frame`` (zero-padded to 8 digits) inside the
    ``camera``, ``depth`` and ``semantic`` folders under ``self.save_path``.

    Args:
        frame: Integer frame number used as the file name.
    """
    if self.save_np:
        # numpy.save appends the ".npy" extension itself.
        np.save(self.save_path + '/camera/%08d' % frame, self.rgb)
        np.save(self.save_path + '/depth/%08d' % frame, self.depth)
        np.save(self.save_path + '/semantic/%08d' % frame, self.semantic)
        return

    # The stored frame is RGB while cv2.imwrite expects BGR channel order.
    bgr = cv2.cvtColor(self.rgb, cv2.COLOR_RGB2BGR)
    cv2.imwrite(self.save_path + '/camera/%08d.png' % frame, bgr)
    cv2.imwrite(self.save_path + '/depth/%08d.png' % frame, self.depth)
    cv2.imwrite(self.save_path + '/semantic/%08d.png' % frame, self.semantic)
(3)定义函数camera_callback,用于处理摄像头传感器的回调函数,将传感器数据转换为RGB图像并保存在 Observation 对象中。具体实现代码如下所示。
def camera_callback(img, obs):
    """Store the RGB version of a camera frame on ``obs``.

    Args:
        img: Sensor image; ignored when its ``raw_data`` is None.
        obs: Observation object whose ``rgb`` attribute is updated.
    """
    if img.raw_data is None:
        return
    obs.rgb = to_rgb_array(img)
(4)定义函数depth_callback,用于处理深度传感器的回调函数,将传感器数据转换为对数灰度深度图像并保存在 Observation 对象中。具体实现代码如下所示。
def depth_callback(img, obs):
    """Store the logarithmic grayscale version of a depth frame on ``obs``.

    Args:
        img: Sensor image; ignored when its ``raw_data`` is None.
        obs: Observation object whose ``depth`` attribute is updated.
    """
    if img.raw_data is None:
        return
    normalized = depth_to_array(img)
    obs.depth = depth_to_logarithmic_grayscale(normalized)
(5)定义函数semantic_callback,用于处理语义分割传感器的回调函数,将传感器数据转换为语义分割图像并保存在 Observation 对象中。具体实现代码如下所示。
def semantic_callback(img, obs):
    """Store the per-pixel label map of a semantic frame on ``obs``.

    Args:
        img: Sensor image; ignored when its ``raw_data`` is None.
        obs: Observation object whose ``semantic`` attribute is updated.
    """
    if img.raw_data is None:
        return
    obs.semantic = labels_to_array(img)
(6)定义函数obstacle_callback,用于处理障碍物传感器的回调函数,记录障碍物的距离和触发帧数。具体实现代码如下所示。
def obstacle_callback(event, obs):
    """Record the frame and distance of an obstacle detection on ``obs``.

    The reported distance is kept only when the detected actor's ``type_id``
    contains ``'vehicle'``; any other actor is recorded with 25.

    Args:
        event: Detection event exposing ``frame``, ``distance`` and
            ``other_actor.type_id``.
        obs: Observation object whose ``obstacle_dist`` is updated.
    """
    hit_vehicle = 'vehicle' in event.other_actor.type_id
    distance = event.distance if hit_vehicle else 25.
    obs.obstacle_dist = (event.frame, distance)
(7)定义函数get_spectator_transform,功能是根据车辆的位置和方向计算观察者(spectator)的位置和方向,用于在模拟中观察自动驾驶车辆。具体实现代码如下所示。
def get_spectator_transform(vehicle_transform, d=7):
    """Turn a vehicle transform into a chase-camera transform, in place.

    The location is raised by 3, moved ``d`` units backwards along the yaw
    direction, and the pitch is lowered by 20.

    Args:
        vehicle_transform: Transform with ``location`` (x, y, z) and
            ``rotation`` (yaw, pitch). It is modified and returned.
        d: Distance behind the vehicle along its heading.

    Returns:
        The same ``vehicle_transform`` object after modification.
    """
    location = vehicle_transform.location
    rotation = vehicle_transform.rotation
    heading = math.radians(rotation.yaw)

    location.z += 3
    location.x += -d * math.cos(heading)
    location.y += -d * math.sin(heading)
    rotation.pitch += -20
    return vehicle_transform
(8)定义函数setup_sensors,功能是初始化各种传感器,包括摄像头、深度传感器、语义分割传感器和障碍物传感器,并将它们附加到自动驾驶车辆上。具体实现代码如下所示。
def setup_sensors(ego_vehicle, blueprint_library, obs, world):
    """Spawn the sensors on the ego vehicle and start streaming into ``obs``.

    Three cameras (RGB, depth, semantic segmentation) share one configuration
    and mounting transform; an obstacle sensor is attached at the vehicle
    origin.

    Args:
        ego_vehicle: Actor the sensors are attached to.
        blueprint_library: Library used to look up the sensor blueprints.
        obs: Observation object updated by the sensor callbacks.
        world: World used to spawn the sensor actors.

    Returns:
        List of spawned sensors in the order RGB camera, depth camera,
        semantic camera, obstacle sensor.
    """
    camera_specs = (
        ('sensor.camera.rgb', camera_callback),
        ('sensor.camera.depth', depth_callback),
        ('sensor.camera.semantic_segmentation', semantic_callback),
    )
    sensors = []
    for blueprint_id, callback in camera_specs:
        sensors.append(
            _spawn_camera(world, blueprint_library, ego_vehicle, blueprint_id, callback, obs)
        )

    obstacle_bp = blueprint_library.find('sensor.other.obstacle')
    obstacle_bp.set_attribute('only_dynamics', 'False')
    obstacle_bp.set_attribute('distance', '20')
    obstacle_bp.set_attribute('sensor_tick', str(TICK))
    obstacle = world.spawn_actor(obstacle_bp, carla.Transform(), attach_to=ego_vehicle)
    obstacle.listen(lambda data: obstacle_callback(data, obs))
    sensors.append(obstacle)
    return sensors


def _spawn_camera(world, blueprint_library, ego_vehicle, blueprint_id, callback, obs):
    """Spawn one camera-type sensor on the ego vehicle and route its frames to ``callback``."""
    camera_bp = blueprint_library.find(blueprint_id)
    camera_bp.set_attribute('image_size_x', str(CAM_WIDTH))
    camera_bp.set_attribute('image_size_y', str(CAM_HEIGHT))
    camera_bp.set_attribute('fov', str(CAM_FOV))
    camera_bp.set_attribute('sensor_tick', str(TICK))
    camera_transform = carla.Transform(
        carla.Location(x=CAM_POS_X, y=CAM_POS_Y, z=CAM_POS_Z),
        carla.Rotation(pitch=CAM_PITCH, yaw=CAM_YAW, roll=CAM_ROLL))
    camera = world.spawn_actor(camera_bp, camera_transform, attach_to=ego_vehicle)
    # ``callback`` is a parameter of this helper, so each sensor keeps its own
    # callback (no late-binding surprise as with a lambda created in a loop).
    camera.listen(lambda data: callback(data, obs))
    return camera
(9)定义函数main,这是主要的执行函数,用于创建CARLA客户端、加载CARLA世界、生成车辆和传感器、模拟车辆运行、收集传感器数据以及保存数据。main 函数的核心功能是在CARLA模拟环境中模拟自动驾驶车辆的行驶,实时收集传感器数据,并将数据保存到指定的目录中。整个过程是一个连续的模拟任务,通过多个episode执行,每个episode都包含自动驾驶车辆的行驶和数据收集。具体实现代码如下所示。
def main(args):
# Entry point of the collection script: connects to a CARLA server, runs the
# data-collection episodes and writes frames plus info.pkl under
# args.out_folder.
# NOTE(review): this listing has lost its indentation; the nesting of the
# statements below has to be restored from the original source file.
#Create output directory
os.makedirs(args.out_folder, exist_ok=True)
camera_path = os.path.join(args.out_folder, 'camera')
depth_path = os.path.join(args.out_folder, 'depth')
semantic_path = os.path.join(args.out_folder, 'semantic')
os.makedirs(camera_path, exist_ok=True)
os.makedirs(depth_path, exist_ok=True)
os.makedirs(semantic_path, exist_ok=True)
# Weather preset is looked up by name; ClearNoon is the fallback.
weather = getattr(carla.WeatherParameters, args.weather, carla.WeatherParameters.ClearNoon)
# The traffic manager uses the port right after the world port.
tfm_port = args.world_port + 1
sensors = []
vehicles = []
ego_vehicle = None
original_settings = None
# Reuse an existing info.pkl so new frames are added to earlier records.
if os.path.exists(f"{args.out_folder}/info.pkl"):
with open(f"{args.out_folder}/info.pkl", 'rb') as f:
info = pickle.load(f)
else:
info = {}
try:
#Connect client to server
client = carla.Client(args.host, args.world_port)
client.set_timeout(60.0)
#Load world
world = client.load_world(args.map)
# Kept so the finally block can restore the settings.
original_settings = world.get_settings()
world.set_weather(weather)
# Synchronous mode with a fixed time step of TICK.
settings = world.get_settings()
settings.synchronous_mode = True
settings.fixed_delta_seconds = TICK
world.apply_settings(settings)
blueprint_library = world.get_blueprint_library()
spawn_points = world.get_map().get_spawn_points()
carla_map = world.get_map()
spectator = world.get_spectator()
#Set traffic manager
traffic_manager = client.get_trafficmanager(tfm_port)
traffic_manager.set_synchronous_mode(True)
traffic_manager.set_hybrid_physics_mode(False)
#Vehicle blueprint
vehicle_bp = blueprint_library.find(EGO_BP)
#Routes
# routes[k] is the spawn-point index of the ego vehicle for route k;
# spawns[k] lists the candidate spawn-point indices for other vehicles.
routes = [73, 14, 61, 241, 1, 167, 71, 233, 176, 39]
spawns = [[147, 190, 146, 192, 240, 143, 195, 196, 197, 17, 14, 110, 111, 117, 201, 203, 115, 13, 58],
[11, 115, 95, 13, 10, 96, 58, 60, 114, 61, 232, 238, 8],
[114, 113, 235, 230, 238, 112, 201, 207, 102, 204, 117],
[88, 226, 228, 229, 231, 102, 239, 101, 114, 238, 76, 104],
[75, 105, 99, 106, 77, 200, 168, 107, 3, 139, 167, 18],
[92, 105, 200, 222, 77, 106, 1, 134, 99, 107],
[221, 220, 223, 68, 97, 224, 119],
[82, 78, 42, 40, 38, 48, 126, 174, 152, 90],
[152, 163, 63, 130, 234, 236, 65],
[79, 37, 90, 36, 34, 29, 31]]
# Episodes cycle through the routes; args.begin allows starting part-way.
for episode in range(args.begin, args.nb_passes*len(routes)):
print("Episode %d" % episode)
route_id = episode % len(routes)
# NOTE(review): try_spawn_actor can return None; the result is used
# below without a check.
ego_vehicle = world.try_spawn_actor(vehicle_bp, spawn_points[routes[route_id]])
traffic_manager.set_desired_speed(ego_vehicle, 35.)
traffic_manager.ignore_lights_percentage(ego_vehicle, 100.)
traffic_manager.random_left_lanechange_percentage(ego_vehicle, 0)
traffic_manager.random_right_lanechange_percentage(ego_vehicle, 0)
traffic_manager.auto_lane_change(ego_vehicle, False)
traffic_manager.distance_to_leading_vehicle(ego_vehicle, 1.0)
ego_vehicle.set_autopilot(True, tfm_port)
# Each candidate spawn point is used with probability 0.7.
for spawn_id in spawns[route_id]:
if np.random.uniform(0., 1.0) < 0.7:
bp = blueprint_library.find(random.choice(EXO_BP))
vehicle = world.try_spawn_actor(bp, spawn_points[spawn_id])
if vehicle is not None:
traffic_manager.set_desired_speed(vehicle, 20.)
traffic_manager.ignore_lights_percentage(vehicle, 25.)
vehicle.set_autopilot(True, tfm_port)
vehicles.append(vehicle)
obs = Observation(args.out_folder, args.np)
sensors = setup_sensors(ego_vehicle, blueprint_library, obs, world)
# Ten ticks before recording starts (presumably a warm-up; confirm).
for _ in range(10):
world.tick()
noise_steps = 0
for i in range(args.nb_frames):
control = ego_vehicle.get_control()
# Steering noise: the counter restarts every 28 frames and Gaussian
# noise (std 0.05) is added while the counter is <= 7.
if i % 28 == 0:
noise_steps = 0
if noise_steps <= 7:
control.steer += np.random.normal(0.0, 0.05)
noise_steps += 1
ego_vehicle.apply_control(control)
w_frame = world.tick()
spectator.set_transform(get_spectator_transform(ego_vehicle.get_transform()))
car_loc = ego_vehicle.get_location()
current_wp = carla_map.get_waypoint(car_loc)
# Lane-centre line is taken through the current waypoint and the
# waypoint 0.001 further along.
distance_center = distance_from_center(current_wp, current_wp.next(0.001)[0], car_loc)
d_left, d_right, _, _ = dist_to_roadline(carla_map, ego_vehicle)
d_roadline = min(d_left, d_right)
# Negative yaws are shifted by 360 and the absolute difference is folded
# so that it never exceeds 180.
ego_yaw = ego_vehicle.get_transform().rotation.yaw
wp_yaw = current_wp.transform.rotation.yaw
if ego_yaw < 0:
ego_yaw += 360
if wp_yaw < 0:
wp_yaw += 360
yaw_diff = abs(ego_yaw - wp_yaw)
if yaw_diff > 180:
yaw_diff = 360 - yaw_diff
speed = ego_vehicle.get_velocity()
speed = math.sqrt(speed.x**2 + speed.y**2 + speed.z**2)
# The obstacle reading is used only when its frame is within 2 frames of
# the current world frame; otherwise 25 is used.
obstacle_dist = obs.obstacle_dist[1] if abs(obs.obstacle_dist[0]-w_frame) < 2 else 25.
if i%args.freq_save == 0:
# Record layout: [obstacle_dist, distance_center, yaw_diff, d_roadline,
# is_junction, speed], keyed by the world frame number.
info[w_frame] = [obstacle_dist, distance_center, yaw_diff, d_roadline, current_wp.is_junction, speed]
obs.save_data(w_frame)
with open(f"{args.out_folder}/info.pkl", 'wb') as f:
pickle.dump(info, f)
# Tear down this episode's actors and reset the bookkeeping lists.
for sensor in sensors:
sensor.stop()
if sensor.is_alive:
sensor.destroy()
for vehicle in vehicles:
if vehicle.is_alive:
vehicle.set_autopilot(False, tfm_port)
vehicle.destroy()
if ego_vehicle is not None:
ego_vehicle.set_autopilot(False, tfm_port)
ego_vehicle.destroy()
sensors = []
ego_vehicle = None
vehicles = []
finally:
# Destroy whatever is still registered and restore the original world
# settings if they were captured.
for sensor in sensors:
sensor.stop()
sensor.destroy()
if ego_vehicle is not None:
ego_vehicle.set_autopilot(False, tfm_port)
ego_vehicle.destroy()
for vehicle in vehicles:
if vehicle.is_alive:
vehicle.set_autopilot(False, tfm_port)
vehicle.destroy()
if original_settings is not None:
world.apply_settings(original_settings)
17.5.4 创建数据集
编写文件data/create_dataset.py,功能是将不同文件夹中的数据整理成一个数据集文件。这个文件通常是在数据收集和数据预处理后使用的,它可以处理多个文件夹中的数据,包括相机图像、深度图像、语义分割图像以及附加信息,并将它们保存到一个统一的数据集文件中。具体实现代码如下所示。
import os
import pickle
from argparse import ArgumentParser
if __name__ == '__main__':
    # Merge the per-folder recordings into a single pickled list of samples.
    parser = ArgumentParser()
    parser.add_argument('--out_file', type=str, default='dataset.pkl', help='output file name')
    parser.add_argument('--folders', type=str, nargs='+', help='folders to be processed', required=True)
    parser.add_argument('--idxs', type=int, nargs='+', default=None,
                        help='folders class index (separate different weather conditions)')
    parser.add_argument('-np', action='store_true', help='Camera images saved as .npy')
    args = parser.parse_args()

    if args.idxs is None:
        # Default: one class index per folder, in order.
        args.idxs = list(range(len(args.folders)))
    elif len(args.folders) != len(args.idxs):
        # parser.error rather than assert: asserts are stripped under -O, and
        # this reports the problem as a normal command-line usage error.
        parser.error('Number of folders and idxs must be the same')

    if not args.out_file.endswith('.pkl'):
        args.out_file += '.pkl'

    # All three image kinds of one run share the same file extension.
    extension = '.npy' if args.np else '.png'

    final_data = []
    for idx, folder in zip(args.idxs, args.folders):
        if not os.path.exists(folder):
            raise ValueError(f'Folder {folder} does not exist')
        print(f'Processing {folder}...')
        folder = os.path.abspath(folder)
        camera = os.path.join(folder, 'camera')
        depth = os.path.join(folder, 'depth')
        semantic = os.path.join(folder, 'semantic')
        # NOTE: pickle must only be used on files produced by the collection
        # script; never load info.pkl from an untrusted source.
        with open(os.path.join(folder, 'info.pkl'), 'rb') as f:
            data = pickle.load(f)
        for frame, d in data.items():
            filename = f'{frame:08d}{extension}'
            image = os.path.join(camera, filename)
            depth_image = os.path.join(depth, filename)
            semantic_image = os.path.join(semantic, filename)
            # The first three entries of a record become the additional
            # values; entry 4 is stored as the junction flag.
            additional = d[:3]
            junction = d[4]
            final_data.append((idx, image, depth_image, semantic_image, additional, junction))

    with open(args.out_file, 'wb') as f:
        pickle.dump(final_data, f)
上述代码的主要目的是将来自不同文件夹的数据整合到一个数据集文件中,以便后续的训练或分析。我们可以通过命令行参数来指定要处理的文件夹列表、类别索引和输出文件名。