DenseFusion: 6D Object Pose Estimation by Iterative Dense Fusion代码解读
看源码的时候加了一些注释(中英文都有),完整注释版代码见下文。
数据读取
class PoseDataset(data.Dataset):
def __init__(self, mode, num, add_noise, root, noise_trans, refine):
    """Index a LINEMOD-style dataset for DenseFusion.

    Builds parallel per-frame lists of file paths plus per-object
    ground-truth metadata and model point clouds.

    Args:
        mode: 'train', 'test' or 'eval'. 'train' reads train.txt, the
            other two read test.txt; 'test' additionally keeps only every
            10th frame; 'eval' uses SegNet-predicted masks.
        num: number of points to sample per frame (stored as self.num).
        add_noise: if True, apply color jitter to the RGB image.
        root: dataset root directory.
        noise_trans: magnitude of the random translation noise.
        refine: whether the dataset feeds the refinement network.
    """
    self.objlist = [1]
    # self.objlist = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]
    self.mode = mode

    # Parallel lists: index i describes one sample/frame.
    self.list_rgb = []    # RGB image path
    self.list_depth = []  # depth image path
    self.list_label = []  # segmentation-mask path
    self.list_obj = []    # object class id of sample i
    self.list_rank = []   # frame number (int) of sample i
    self.meta = {}        # per-object ground truth (R, t, bbox, class) from gt.yml
    self.pt = {}          # per-object model point cloud from the .ply file
    self.root = root
    self.noise_trans = noise_trans
    self.refine = refine

    for item in self.objlist:
        split = 'train' if self.mode == 'train' else 'test'
        list_path = '{0}/data/{1}/{2}.txt'.format(self.root, '%02d' % item, split)
        # 'with' closes the handle; the original leaked both open files.
        with open(list_path) as input_file:
            for line_no, input_line in enumerate(input_file, start=1):
                # In test mode keep only every 10th frame to thin the set.
                # (Equivalent to the original while/item_count loop, which
                # effectively restarted its counter per file.)
                if self.mode == 'test' and line_no % 10 != 0:
                    continue
                input_line = input_line.rstrip('\n')
                self.list_rgb.append('{0}/data/{1}/rgb/{2}.png'.format(self.root, '%02d' % item, input_line))
                self.list_depth.append('{0}/data/{1}/depth/{2}.png'.format(self.root, '%02d' % item, input_line))
                if self.mode == 'eval':
                    # Evaluation uses SegNet-predicted masks instead of GT masks.
                    self.list_label.append('{0}/segnet_results/{1}_label/{2}_label.png'.format(self.root, '%02d' % item, input_line))
                else:
                    self.list_label.append('{0}/data/{1}/mask/{2}.png'.format(self.root, '%02d' % item, input_line))
                # list_obj[i] + list_rank[i] together identify the frame.
                self.list_obj.append(item)
                self.list_rank.append(int(input_line))

        # Ground-truth annotations keyed by frame number.
        with open('{0}/data/{1}/gt.yml'.format(self.root, '%02d' % item), 'r') as meta_file:
            # safe_load: yaml.load() without an explicit Loader is
            # deprecated (PyYAML >= 5.1) and removed in PyYAML >= 6.
            self.meta[item] = yaml.safe_load(meta_file)
        # 3D model point cloud parsed from the object's PLY file.
        self.pt[item] = ply_vtx('{0}/models/obj_{1}.ply'.format(self.root, '%02d' % item))
        print("Object {0} buffer loaded".format(item))

    self.length = len(self.list_rgb)

    # Camera intrinsics (pixels) — LINEMOD defaults.
    self.cam_cx = 325.26110
    self.cam_cy = 242.04899
    self.cam_fx = 572.41140
    self.cam_fy = 573.57043

    # Pixel-coordinate grids (480x640): xmap[r][c] == r, ymap[r][c] == c.
    # Used later to pick masked pixel coordinates / back-project depth.
    self.xmap = np.array([[j for i in range(640)] for j in range(480)])
    self.ymap = np.array([[i for i in range(640)] for j in range(480)])

    self.num = num
    self.add_noise = add_noise
    # Color augmentation and ImageNet normalization for the RGB crop.
    self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
    self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # Presumably candidate crop edges (multiples of 40 px) that get_bbox
    # snaps bounding boxes to — TODO confirm against get_bbox.
    self.border_list = [-1, 40, 80, 120, 160, 200, 240, 280, 320, 360, 400, 440, 480, 520, 560, 600, 640, 680]
    self.num_pt_mesh_large = 500
    self.num_pt_mesh_small = 500
    # Indices of rotationally symmetric objects (special-cased in the loss).
    self.symmetry_obj_idx = [7, 8]
def __getitem__(self, index):
img = Image.open(self.list_rgb[index])
img.save('/home/ouc/TXH/DenseFusion/visual_refine/ori.jpg')
ori_img = np.array(img)
depth = np.array(Image.open(self.list_depth[index]))
label = np.array(Image.open(self.list_label[index]))
obj = self.list_obj[index]
rank = self.list_rank[index]
if obj == 2:
for i in range(0, len(self.meta[obj][rank])):
if self.meta[obj][rank][i]['obj_id'] == 2:
meta = self.meta[obj][rank][i]
break
else:
meta = self.meta[obj][rank][0]
mask_depth = ma.getmaskarray(ma.masked_not_equal(depth, 0))
if self.mode == 'eval':
mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255)))
else:
mask_label = ma.getmaskarray(ma.masked_equal(label, np.array([255, 255, 255])))[:, :, 0]
# 看来255是被选定的mask,就是白色
# 只有True * True才是True
# for what? for choose
mask = mask_label * mask_depth
if self.add_noise:
img = self.trancolor(img)
img = np.array(img)[:, :, :3]
# channel in first
img = np.transpose(img, (2, 0, 1))
img_masked = img
# get bbox是设置边界的函数?
# 我觉得可能是把物体抠出来,但是borde_list有什么用我不清楚
# NUS, about border_list?
if self.mode == 'eval':
rmin, rmax, cmin, cmax = get_bbox(mask_to_bbox(mask_label))
else:
rmin, rmax, cmin, cmax = get_bbox(meta['obj_bb'])
img_masked = img_masked[:, rmin:rmax, cmin:cmax]
#p_img = np.transpose(img_masked, (1, 2, 0))
#scipy.misc.imsave('evaluation_result/{0}_input.png'.format(index), p_img)
target_r = np.resize(np.array(meta['cam_R_m2c']), (3, 3))
target_t = np.array(meta['cam_t_m2c'])
add_t = np.array([random.uniform(-self.noise_trans, self.noise_trans) for i in range(3)])
# 只要限定范围内的物体
# get infor in bbox
choose = mask[rmin:rmax, cmin:cmax].flatten().nonzero()[0]
if len(choose) == 0:
# 把narray转换成Long类型