最近跑通了densefusion代码,最近在看代码实现,dataset.py是数据预处理部分,做了一些注释,交流学习之用。
dataset.py
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import torch
import json
import codecs
import numpy as np
import sys
import torchvision.transforms as transforms
import argparse
import json
import time
import random
import numpy.ma as ma
import copy
import scipy.misc
import scipy.io as scio
import yaml
import cv2
class PoseDataset(data.Dataset):
def __init__(self, mode, num, add_noise, root, noise_trans, refine):
"""
:param mode: 可以选择train,test,eval
:param num: mesh点的数目
:param add_noise:是否加入噪声
:param root:数据集的根目录
:param noise_trans:噪声增强的相关参数
:param refine:是否需要为refine模型提供相应的数据
"""
# 这里表示目标物体类别序列号
self.objlist = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]
# 可以选择train,test,eval
self.mode = mode
# 存储RGB图像的路径
self.list_rgb = []
# 存储深度图像的路径
self.list_depth = []
# 存储语义分割出来物体对应的mask
self.list_label = []
# 两个拼接起来,可以知道每张图片的路径,及物体类别,和图片下标
self.list_obj = []
self.list_rank = []
# 矩阵信息,拍摄图片时的旋转矩阵和偏移矩阵,以及物体box
self.meta = {}
# 保存目标模型models点云数据,及models/obj_xx.ply文件中的数据
self.pt = {}
# 数据的所在的目录
self.root = root
# 噪声相关参数
self.noise_trans = noise_trans
self.refine = refine
item_count = 0
# 对每个目标物体的相关数据都进行处理
for item in self.objlist:
# 根据训练或者测试获得相应文件中的txt内容,其中保存的都是图片对应的名称数目
if self.mode == 'train':
input_file = open('{0}/data/{1}/train.txt'.format(self.root, '%02d' % item))
else:
input_file = open('{0}/data/{1}/test.txt'.format(self.root, '%02d' % item))
# 循环txt记录的每张图片进行处理
while 1:
item_count += 1
input_line = input_file.readline() # 从文件读取每一整行
# test模式下,图片序列为10的倍数则continue
if self.mode == 'test' and item_count % 10 != 0:
continue
# 文件读取完成
if not input_line:
break
if input_line[-1:] == '\n': # 提取最后一个元素
input_line = input_line[:-1] # 提取除了最后一个元素的全部元素
# 把RGB图像、depth图像的路径加载到列表中
self.list_rgb.append('{0}/data/{1}/rgb/{2}.png'.format(self.root, '%02d' % item, input_line))
self.list_depth.append('{0}/data/{1}/depth/{2}.png'.format(self.root, '%02d' % item, input_line))
# 如果是评估模式,则添加segnet_results图片的mask,否则添加data中的mask图片(该为标准mask)
# 大家可以想想,训练的时候肯定使用最标准的mask,但是在测试的时候,是要结合实际了,所以使用的
# 是通过分割网络分割出来的mask,即则添加segnet_results中的图片
if self.mode == 'eval':
self.list_label.append('{0}/segnet_results/{1}_label/{2}_label.png'.format(self.root, '%02d' % item, input_line))
else:
self.list_label.append('{0}/data/{1}/mask/{2}.png'.format(self.root, '%02d' % item, input_line))
# 把物体的下标,和txt读取到的图片数目标记分别添加到list_obj,list_rank中
self.list_obj.append(item)
self.list_rank.append(int(input_line))
# gt.yml主要保存的,是拍摄图片时,物体的旋转矩阵以及偏移矩阵,以及物体标签的box
# 有了该参数,我们就能把对应的图片,从2维空间恢复到3维空间了
meta_file = open('{0}/data/{1}/gt.yml'.format(self.root, '%02d' % item), 'r')
self.meta[item] = yaml.load(meta_file)
# 这里保存的是目标物体,拍摄物体第一帧的点云数据,可以成为模型数据
self.pt[item] = ply_vtx('{0}/models/obj_{1}.ply'.format(self.root, '%02d' % item))
print("Object {0} buffer loaded".format(item))
# 读取所有图片的路径,以及文件信息之后,打印图片数目
self.length = len(self.list_rgb)
# 摄像头的中中心坐标
self.cam_cx = 325.26110
self.cam_cy = 242.04899
# 摄像头的x,y轴的长度
self.cam_fx = 572.41140
self.cam_fy = 573.57043
# 列举处图片的x和y坐标
# xmap每一个矩阵元素位置都是代表x坐标,ymap每一元素处代表y坐标
self.xmap = np.array([[j for i in range(640)] for j in range(480)]) # 列表推导式。输出一个480*640的列表
self.ymap = np.array([[i for i in range(640)] for j in range(480)])
# 设定获取目标物体点云的数据
self.num = num
self.add_noise = add_noise
# 数据处理,https://dongfangyou.blog.csdn.net/article/details/108022357
self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05) # 修改修改亮度、对比度、饱和度和色相参数
self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 对数据按通道进行标准化,即先减均值,再除以标准差,注意是 chw
# 边界列表,可以想象把一个图片切割成了多个坐标
self.border_list = [-1, 40, 80, 120, 160, 200, 240, 280, 320, 360, 400, 440, 480, 520, 560, 600, 640, 680]
# 点云的最大和最小数目
self.num_pt_mesh_large = 500
self.num_pt_mesh_small = 500
# 标记对称物体的序列号
self.symmetry_obj_idx = [7, 8]
# 迭代的时候,会调用该函数
def __getitem__(self, index):
# 根据索引获得对应图像的RGB像素
img = Image.open(self.list_rgb[index])
ori_img = np.array(img)
# 根据索引获得对应图像的深度图像素
depth = np.array(Image.open(self.list_depth[index]))
# 根据索引获得对应图像的mask像素
label = np.array(Image.open(self.list_label[index]))
# 获得物体属于的类别的序列号
obj = self.list_obj[index]
# 获得该张图片物体图像的标号
rank = self.list_rank[index]
# 如果该目标物体的序列为2,暂时不对序列为2的物体图像进行处理
if obj == 2:
# 对该物体的每个图片下标进行循环
for i in range(0, len(self.meta[obj][rank])):
# 验证该图片目标是否为2,如果是则赋值给meta
if self.meta[obj][rank][i]['obj_id'] == 2:
meta = self.meta[obj][rank][i]
break
else:
meta = self.meta[obj][rank][0] # 把对应类和标号的R,t,obj_bb,obj_id取出来
# 只要像素不为0的,返回值为Ture,如果为0的像素,返回False。masked_not_equal将不等于0的数据掩码, getmaskarray返回掩码数组的掩码的完整布尔数组
mask_depth = ma.getmaskarray(ma.masked_not_equal(depth, 0))
# 标准的数据中的mask是3通道的,通过网络分割出来的mask,其是单通道的
if self.mode == 'eval':
mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255)))
else:
mask_label = ma.getmaskarray(ma.masked_equal(label, np.array([255, 255, 255])))[:, :, 0] # 取三维矩阵中第一维的所有数据
# 把mask_label和深度图结合到到一起,物体存在的区域像素为True,背景像素为False
mask = mask_label * mask_depth
# 对图像加入噪声
if self.add_noise:
img = self.trancolor(img)
# [b,h,w,c] --> [b,c,h,w]
img = np.array(img)[:, :, :3] # 图像转换成数组
# print(img.shape) # (480,640,3)
img = np.transpose(img, (2, 0, 1))
img_masked = img
# print(img.shape) # (3,480,640)
# 如果为eval模式,根据mask_label获得目标的box(rmin-rmax表示行的位置,cmin-cmax表示列的位置)
if self.mode == 'eval':
# mask_to_bbox表示的根据mask生成合适的box
rmin, rmax, cmin, cmax = get_bbox(mask_to_bbox(mask_label))
else:
# 如果不是eval模式,则从gt.yml文件中获取最标准的box
rmin, rmax, cmin, cmax = get_bbox(meta['obj_bb'])
# 根据确定的行和列,对图像进行截取,就是截取处包含了目标物体的图像
img_masked = img_masked[:, rmin:rmax, cmin:cmax]
#p_img = np.transpose(img_masked, (1, 2, 0))
#scipy.misc.imsave('evaluation_result/{0}_input.png'.format(index), p_img)
# 获得目标物体旋转矩阵的参数,以及偏移矩阵参数
target_r = np.resize(np.array(meta['cam_R_m2c']), (3, 3))
target_t = np.array(meta['cam_t_m2c'])
# 对偏移矩阵添加噪声
add_t = np.array([random.uniform(-self.noise_trans, self.noise_trans) for i in range(3)])
# 对mask图片的目标部分进行剪裁,变成拉成一维,返回数组中值不为零的元素的标号
choose = mask[rmin:rmax, cmin:cmax].flatten().nonzero()[0]
# 如果剪切下来的部分面积为0,则直接返回五个0的元组,即表示没有目标物体
if len(choose) == 0:
cc = torch.LongTensor([0])
return(cc, cc, cc, cc, cc, cc)
# 如果剪切下来目标图片的像素,大于点云的数目(一般都是这种情况)
if len(choose) > self.num:
c_mask = np.zeros(len(choose), dtype=int) # c_mask全部设置为0,大小和choose相同
c_mask[:self.num] = 1 # 前self.num设置为
np.random.shuffle(c_mask) # 随机打乱
choose = choose[c_mask.nonzero()] # 选择c_mask不是0的部分,也就是说只选择了500个像素,注意nonzero()返回的是索引
# 如果剪切像素点的数目,小于了点云的数目
else:
choose = np.pad(choose, (0, self.num - len(choose)), 'wrap') # 这使用0填补,调整到和点云数目一样大小,500
# 把深度图,对应着物体的部分也剪切下来,然后拉平变成一维,挑选坐标
depth_masked = depth[rmin:rmax, cmin:cmax].flatten()[choose][:, np.newaxis].astype(np.float32)
# 把物体存在于原图的位置坐标剪切下来,拉平然后进行挑选坐标
xmap_masked = self.xmap[rmin:rmax, cmin:cmax].flatten()[choose][:, np.newaxis].astype(np.float32)
ymap_masked = self.ymap[rmin:rmax, cmin:cmax].flatten()[choose][:, np.newaxis].astype(np.float32)
choose = np.array([choose]) # (500,)
# 根据相机内参和像素坐标求解点云的三维坐标
cam_scale = 1.0 # 摄像头缩放参数
pt2 = depth_masked / cam_scale # z值(depth) (500,1)
pt0 = (ymap_masked - self.cam_cx) * pt2 / self.cam_fx # y值(500,1)
pt1 = (xmap_masked - self.cam_cy) * pt2 / self.cam_fy # x值 (500,1)
cloud = np.concatenate((pt0, pt1, pt2), axis=1) # 把y,x,depth3个坐标合并成点云数据,(500,3)
cloud = cloud / 1000.0 # 这里把点云数据除以1000,是为了根据深度进行正则化
# 对点云添加噪声
if self.add_noise:
cloud = np.add(cloud, add_t)
# 存储在obj_xx.ply中的点云数据,对其进行正则化,也就是目标物体的点云信息
model_points = self.pt[obj] / 1000.0
dellist = [j for j in range(0, len(model_points))] # 复制model_points列表,列表解析式
dellist = random.sample(dellist, len(model_points) - self.num_pt_mesh_small) # 截取dellist列表的指定长度的随机数,但是不会改变列表本身的排序
model_points = np.delete(model_points, dellist, axis=0) # 随机删除多余的点云数据,训练时只需要num_pt_mesh_small数目的点云
# 根据model_points(第一帧目标模型对应的点云信息)以及target(目前迭代这张图片)的旋转和偏移矩阵,计算出对应的点云数据
target = np.dot(model_points, target_r.T)
if self.add_noise:
target = np.add(target, target_t / 1000.0 + add_t)
out_t = target_t / 1000.0 + add_t
else:
target = np.add(target, target_t / 1000.0)
out_t = target_t / 1000.0
return torch.from_numpy(cloud.astype(np.float32)), \
torch.LongTensor(choose.astype(np.int32)), \
self.norm(torch.from_numpy(img_masked.astype(np.float32))), \
torch.from_numpy(target.astype(np.float32)), \
torch.from_numpy(model_points.astype(np.float32)), \
torch.LongTensor([self.objlist.index(obj)])
# 总结:
# cloud:由深度图计算出来的目标区域点云,该点云数据以本摄像头为参考坐标
# choose:所选择点云的列索引
# img_masked:data/rgb的图像通过box剪切下来的目标区域RGB图像
# target:根据model_points(第一帧目标模型对应的点云信息)以及target(目前迭代这张图片)的旋转和偏移矩阵,计算出对应的点云数据,500个点
# model_points:目标初始帧(模型)对应的点云预处理后的标准点云信息,500个点
# [self.objlist.index(obj)]:目标物体的序列编号
def __len__(self):
return self.length
def get_sym_list(self):
return self.symmetry_obj_idx
def get_num_points_mesh(self):
if self.refine:
return self.num_pt_mesh_large
else:
return self.num_pt_mesh_small
border_list = [-1, 40, 80, 120, 160, 200, 240, 280, 320, 360, 400, 440, 480, 520, 560, 600, 640, 680]
img_width = 480
img_length = 640
def mask_to_bbox(mask):
mask = mask.astype(np.uint8)
_, contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
x = 0
y = 0
w = 0
h = 0
for contour in contours:
tmp_x, tmp_y, tmp_w, tmp_h = cv2.boundingRect(contour)
if tmp_w * tmp_h > w * h:
x = tmp_x
y = tmp_y
w = tmp_w
h = tmp_h
return [x, y, w, h]
def get_bbox(bbox):
bbx = [bbox[1], bbox[1] + bbox[3], bbox[0], bbox[0] + bbox[2]]
if bbx[0] < 0:
bbx[0] = 0
if bbx[1] >= 480:
bbx[1] = 479
if bbx[2] < 0:
bbx[2] = 0
if bbx[3] >= 640:
bbx[3] = 639
rmin, rmax, cmin, cmax = bbx[0], bbx[1], bbx[2], bbx[3]
r_b = rmax - rmin
for tt in range(len(border_list)):
if r_b > border_list[tt] and r_b < border_list[tt + 1]:
r_b = border_list[tt + 1]
break
c_b = cmax - cmin
for tt in range(len(border_list)):
if c_b > border_list[tt] and c_b < border_list[tt + 1]:
c_b = border_list[tt + 1]
break
center = [int((rmin + rmax) / 2), int((cmin + cmax) / 2)]
rmin = center[0] - int(r_b / 2)
rmax = center[0] + int(r_b / 2)
cmin = center[1] - int(c_b / 2)
cmax = center[1] + int(c_b / 2)
if rmin < 0:
delt = -rmin
rmin = 0
rmax += delt
if cmin < 0:
delt = -cmin
cmin = 0
cmax += delt
if rmax > 480:
delt = rmax - 480
rmax = 480
rmin -= delt
if cmax > 640:
delt = cmax - 640
cmax = 640
cmin -= delt
return rmin, rmax, cmin, cmax
def ply_vtx(path):
f = open(path)
assert f.readline().strip() == "ply"
f.readline()
f.readline()
N = int(f.readline().split()[-1])
while f.readline().strip() != "end_header":
continue
pts = []
for _ in range(N):
pts.append(np.float32(f.readline().split()[:3]))
return np.array(pts)