frustum pointnets训练代码学习笔记——kitti_object.py

frustum pointnets训练代码学习笔记——kitti_object.py

本文记录了博主学习frustum pointnets过程中遇到的2D和3D数据库显示程序。为了画出输出结果,博主希望在这个程序的基础上修改一个可以显示结果的程序。更新于2018.09.22。

本文首先给出代码原文的学习笔记,随后整理出修改后的结果显示程序,如果公开,会在这里放上链接,如果有帮助,请在代码页面点一下小星星哦。附可能有用的信息:各个集合所用的文件名在kitti/image_sets文件夹下。

总结

把总结写在前面,根据需要判断是否需要详细看源码分析。

  • 这个文件的主要功能就是将KITTI库中的2d和3d结果画出来(至于是画training还是testing,文件的kitti_object函数的初始函数和get_label_objects分别有定义和判断)。
  • 文件个数是人为在kitti_object函数中设定的,并非自动提取。
  • 画图是机械地从第一个图片一直向后显示,且如果该图片中有多个目标,仅显示txt文件中排在第一的那个的数据。通过修改objects[0].print_object()中[]里面的标号可以指定画第几个目标,不过要注意的是,每个图片中含有的目标个数是不同的。

用到的语法规则

这一部分记录了代码原文中出现的语法规则,并不影响代码功能的理解,但是可能方便日后的使用,因此在这里记录下来。

from __future__ import print_function

加上这句话以后,即使在python2.X也要像python3.X一样的语法使用print函数(加括号)。类似地,如果有其他新的功能特性且该特性与当前版本中的使用不兼容,就可以从future模块导入。详细说明参考这里

from PIL import Image

PIL已经是python平台事实上的图像处理标准库了,全称为Python Imaging Library。具体的使用方法说明可以参考这里

BASE_DIR = os.path.dirname(os.path.abspath(__file__))

其中,__file__就是当前所执行的文件,也就是kitti_object.py。os.path.abspath命令获取的是当前文件的绝对路径,比如博主的运行结果:

>>> print os.path.abspath("/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti/kitti_object.py")
/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti/kitti_object.py

而前面的os.path.dirname获取的则是当前路径所存在于的文件夹,因此,BASE_DIR指向的就是kitti_object.py所处的文件夹了。运行结果为:

>>> print os.path.dirname(os.path.abspath("/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti/kitti_object.py"))
/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti

sys.path.append(os.path.join(ROOT_DIR, 'mayavi'))

其中,os.path.join用于路径拼接。
sys.path.append:在导入一个模块时,默认情况下python会搜索当前目录、已安装的内置模块和第三方模块,搜索路径存放在sys模块的path中。如果要用的模块和当前脚本不在一个目录下,就需要将其添加到path中。这种修改是临时的,脚本运行后失效。

img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

opencv中提供了cvtColor函数用于实现图像格式类型的相互转换,具体说明可以参照这里


代码原文分析

#代码作者信息
''' Helper class and functions for loading KITTI objects

Author: Charles R. Qi
Date: September 2017
'''

#加载必要的库
from __future__ import print_function			

import os
import sys
import numpy as np
import cv2
from PIL import Image
#定义基础路径
BASE_DIR = os.path.dirname(os.path.abspath(__file__))			#指向当前文件所在文件夹(kitti)
ROOT_DIR = os.path.dirname(BASE_DIR)			#指向frustum文件夹
sys.path.append(os.path.join(ROOT_DIR, 'mayavi'))
import kitti_util as utils			#加载论文作者写的库

# Python 2/3 compatibility shim: ensure the name `raw_input` exists.
try:
    raw_input          # present on Python 2; bare reference just probes it
except NameError:
    raw_input = input  # Python 3: input() behaves like the old raw_input()


#用于获取数据库各项路径(training或testing)
class kitti_object(object):
    '''Load and parse object data into a usable format.

    Resolves the per-split sub-directories of a KITTI object-detection
    dataset root (images, calibration, velodyne scans, labels).
    '''

    def __init__(self, root_dir, split='training'):
        '''root_dir contains training and testing folders'''
        self.root_dir = root_dir
        self.split = split
        self.split_dir = os.path.join(root_dir, split)

        # The KITTI object benchmark ships a fixed number of frames per split.
        split_sizes = {'training': 7481, 'testing': 7518}
        if split not in split_sizes:
            print('Unknown split: %s' % (split))
            exit(-1)
        self.num_samples = split_sizes[split]

        # Sub-directories holding the four data modalities of this split.
        self.image_dir = os.path.join(self.split_dir, 'image_2')
        self.calib_dir = os.path.join(self.split_dir, 'calib')
        self.lidar_dir = os.path.join(self.split_dir, 'velodyne')
        self.label_dir = os.path.join(self.split_dir, 'label_2')

    def __len__(self):
        # Number of samples in the chosen split.
        return self.num_samples

    def get_image(self, idx):
        '''Load the left color image for sample idx.'''
        assert(idx < self.num_samples)
        path = os.path.join(self.image_dir, '%06d.png' % (idx))
        return utils.load_image(path)

    def get_lidar(self, idx):
        '''Load the velodyne point cloud for sample idx.'''
        assert(idx < self.num_samples)
        path = os.path.join(self.lidar_dir, '%06d.bin' % (idx))
        return utils.load_velo_scan(path)

    def get_calibration(self, idx):
        '''Load the calibration object for sample idx.'''
        assert(idx < self.num_samples)
        path = os.path.join(self.calib_dir, '%06d.txt' % (idx))
        return utils.Calibration(path)

    def get_label_objects(self, idx):
        '''Read the label file for sample idx (training split only).'''
        assert(idx < self.num_samples and self.split == 'training')
        path = os.path.join(self.label_dir, '%06d.txt' % (idx))
        return utils.read_label(path)

    def get_depth_map(self, idx):
        # Not implemented.
        pass

    def get_top_down(self, idx):
        # Not implemented.
        pass

class kitti_object_video(object):
    ''' Load data for KITTI videos '''

    def __init__(self, img_dir, lidar_dir, calib_dir):
        # One calibration is shared by every frame of the raw sequence.
        self.calib = utils.Calibration(calib_dir, from_video=True)
        self.img_dir = img_dir
        self.lidar_dir = lidar_dir
        # Sorted so image and lidar frames line up by filename order.
        self.img_filenames = sorted(
            os.path.join(img_dir, name) for name in os.listdir(img_dir))
        self.lidar_filenames = sorted(
            os.path.join(lidar_dir, name) for name in os.listdir(lidar_dir))
        print(len(self.img_filenames))
        print(len(self.lidar_filenames))
        #assert(len(self.img_filenames) == len(self.lidar_filenames))
        self.num_samples = len(self.img_filenames)

    def __len__(self):
        return self.num_samples

    def get_image(self, idx):
        '''Load frame idx's image.'''
        assert(idx < self.num_samples)
        return utils.load_image(self.img_filenames[idx])

    def get_lidar(self, idx):
        '''Load frame idx's velodyne scan.'''
        assert(idx < self.num_samples)
        return utils.load_velo_scan(self.lidar_filenames[idx])

    def get_calibration(self, unused):
        # Same calibration for all frames; the index argument is ignored.
        return self.calib

def viz_kitti_video():
    '''Step through a KITTI raw video, showing each frame's image and lidar.

    Blocks on raw_input() between views. Requires `draw_lidar` (mayavi) to be
    importable at module scope and the raw sequence on disk under ROOT_DIR.
    '''
    video_path = os.path.join(ROOT_DIR, 'dataset/2011_09_26/')
    dataset = kitti_object_video(
        os.path.join(video_path, '2011_09_26_drive_0023_sync/image_02/data'),
        os.path.join(video_path, '2011_09_26_drive_0023_sync/velodyne_points/data'),
        video_path)
    print(len(dataset))
    for i in range(len(dataset)):
        # BUG FIX: the original always fetched frame 0 inside the loop;
        # use the loop index so every iteration shows the next frame.
        img = dataset.get_image(i)
        pc = dataset.get_lidar(i)
        Image.fromarray(img).show()
        draw_lidar(pc)
        raw_input()
        # BUG FIX: get_calibration() takes a (ignored) frame index argument;
        # the original called it with no arguments, which raises TypeError.
        pc[:, 0:3] = dataset.get_calibration(i).project_velo_to_rect(pc[:, 0:3])
        draw_lidar(pc)
        raw_input()
    return

def show_image_with_boxes(img, objects, calib, show3d=True):
    ''' Show image with 2D bounding boxes '''
    img1 = np.copy(img)  # canvas for axis-aligned 2D boxes
    img2 = np.copy(img)  # canvas for projected 3D boxes
    for obj in objects:
        if obj.type == 'DontCare':
            continue
        # 2D box straight from the label's pixel extents.
        top_left = (int(obj.xmin), int(obj.ymin))
        bottom_right = (int(obj.xmax), int(obj.ymax))
        cv2.rectangle(img1, top_left, bottom_right, (0, 255, 0), 2)
        # 3D box corners projected into the image via calibration matrix P.
        box3d_pts_2d, box3d_pts_3d = utils.compute_box_3d(obj, calib.P)
        img2 = utils.draw_projected_box3d(img2, box3d_pts_2d)
    Image.fromarray(img1).show()
    if show3d:
        Image.fromarray(img2).show()

def get_lidar_in_image_fov(pc_velo, calib, xmin, ymin, xmax, ymax,
                           return_more=False, clip_distance=2.0):
    ''' Filter lidar points, keep those in image FOV '''
    pts_2d = calib.project_velo_to_image(pc_velo)
    # Inside the image rectangle: [xmin, xmax) x [ymin, ymax).
    in_x = (pts_2d[:, 0] >= xmin) & (pts_2d[:, 0] < xmax)
    in_y = (pts_2d[:, 1] >= ymin) & (pts_2d[:, 1] < ymax)
    # Also drop points too close to the sensor along the forward (x) axis.
    fov_inds = in_x & in_y & (pc_velo[:, 0] > clip_distance)
    imgfov_pc_velo = pc_velo[fov_inds, :]
    if return_more:
        return imgfov_pc_velo, pts_2d, fov_inds
    return imgfov_pc_velo

def show_lidar_with_boxes(pc_velo, objects, calib,
                          img_fov=False, img_width=None, img_height=None):
    ''' Show all LiDAR points.
        Draw 3d box in LiDAR point cloud (in velo coord system) '''
    # Lazy imports: mayavi is heavy and only needed for visualization.
    if 'mlab' not in sys.modules: import mayavi.mlab as mlab
    from viz_util import draw_lidar_simple, draw_lidar, draw_gt_boxes3d

    print(('All point num: ', pc_velo.shape[0]))
    fig = mlab.figure(figure=None, bgcolor=(0, 0, 0),
                      fgcolor=None, engine=None, size=(1000, 500))
    if img_fov:
        # Restrict the cloud to points that project inside the image.
        pc_velo = get_lidar_in_image_fov(pc_velo, calib, 0, 0,
                                         img_width, img_height)
        print(('FOV point num: ', pc_velo.shape[0]))
    draw_lidar(pc_velo, fig=fig)

    for obj in objects:
        if obj.type == 'DontCare':
            continue
        # 3D box: computed in rect camera coords, converted to velo coords.
        box3d_pts_2d, box3d_pts_3d = utils.compute_box_3d(obj, calib.P)
        box3d_pts_3d_velo = calib.project_rect_to_velo(box3d_pts_3d)
        draw_gt_boxes3d([box3d_pts_3d_velo], fig=fig)
        # Heading arrow: a two-point segment indicating the facing direction.
        ori3d_pts_2d, ori3d_pts_3d = utils.compute_orientation_3d(obj, calib.P)
        ori3d_pts_3d_velo = calib.project_rect_to_velo(ori3d_pts_3d)
        x1, y1, z1 = ori3d_pts_3d_velo[0, :]
        x2, y2, z2 = ori3d_pts_3d_velo[1, :]
        mlab.plot3d([x1, x2], [y1, y2], [z1, z2], color=(0.5, 0.5, 0.5),
                    tube_radius=None, line_width=1, figure=fig)
    mlab.show(1)

def show_lidar_on_image(pc_velo, img, calib, img_width, img_height):
    ''' Project LiDAR points to image '''
    imgfov_pc_velo, pts_2d, fov_inds = get_lidar_in_image_fov(
        pc_velo, calib, 0, 0, img_width, img_height, True)
    imgfov_pts_2d = pts_2d[fov_inds, :]
    # Depth (rect camera z) of each kept point, used to color the dots.
    imgfov_pc_rect = calib.project_velo_to_rect(imgfov_pc_velo)

    import matplotlib.pyplot as plt
    # 256-entry HSV colormap as an array of BGR-scaled rows.
    cmap = plt.cm.get_cmap('hsv', 256)
    cmap = np.array([cmap(i) for i in range(256)])[:, :3] * 255

    for i in range(imgfov_pts_2d.shape[0]):
        depth = imgfov_pc_rect[i, 2]
        # BUG FIX: int(640.0/depth) exceeds 255 when depth < 2.5 (and could
        # be negative for points behind the camera), indexing out of bounds;
        # clamp the colormap index to [0, 255].
        color = cmap[min(255, max(0, int(640.0 / depth))), :]
        cv2.circle(img, (int(np.round(imgfov_pts_2d[i, 0])),
                         int(np.round(imgfov_pts_2d[i, 1]))),
                   2, color=tuple(color), thickness=-1)
    Image.fromarray(img).show()
    return img

def dataset_viz():
    '''Walk the KITTI object training set and visualize each frame's boxes.'''
    # Resolve the per-split data directories under the dataset root.
    dataset = kitti_object(os.path.join(ROOT_DIR, 'dataset/KITTI/object'))

    for data_idx in range(len(dataset)):
        # Labels, image, point cloud and calibration for this frame.
        objects = dataset.get_label_objects(data_idx)
        objects[0].print_object()  # dump the first labelled object to stdout
        img = dataset.get_image(data_idx)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        img_height, img_width, img_channel = img.shape
        print(('Image shape: ', img.shape))
        pc_velo = dataset.get_lidar(data_idx)[:, 0:3]  # xyz only, drop intensity
        calib = dataset.get_calibration(data_idx)

        # 2D boxes on the image (projected-3D view disabled), then pause.
        show_image_with_boxes(img, objects, calib, False)
        raw_input()
        # 3D boxes in the lidar cloud, limited to the image FOV, then pause.
        show_lidar_with_boxes(pc_velo, objects, calib, True, img_width, img_height)
        raw_input()

if __name__ == '__main__':
    # Visualization-only deps are imported here so that merely importing this
    # module does not require mayavi or a display.
    import mayavi.mlab as mlab
    from viz_util import draw_lidar_simple, draw_lidar, draw_gt_boxes3d
    dataset_viz()  # show the data
  • 4
    点赞
  • 38
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值