class Config
self.dataset_name 数据集名称
self.exp_dir 当前文件路径
self.exp_name 当前文件名
self.resnet_ptr_mdl_p RESNET预训练模型
self.log_dir train_log
self.log_model_dir checkpoints
self.log_eval_dir eval_results
self.n_total_epoch = 25
self.mini_batch_size = 2
self.num_mini_batch_per_epoch = 105 * 4
self.val_mini_batch_size = 2
self.val_num_mini_batch_per_epoch = 43 * 4
self.test_mini_batch_size = 1
self.n_sample_points = 8192 + 4096
self.n_keypoints = 20
self.n_min_points = 400
self.noise_trans = 0.05
self.n_objects = 1 + 1
self.n_classes = 1 + 1
self.pallet_cls_lst_p 数据集的类别定义文件
self.pallet_root 数据集根目录
self.pallet_kps_dir 数据集的21个物体的角点,最远点三维坐标,是物体的信息
datasets/pallet/Pallet_obj_kps/
pallet_r_lst_p 21个物体的半径
datasets/pallet/dataset_config/radius.txt
self.pallet_r_lst np.loadtxt(pallet_r_lst_p).tolist()
21个物体的半径列表
self.pallet_cls_lst self.read_lines(self.pallet_cls_lst_p)
21个物体的类别列表
self.val_test_pkl_p = os.path.join(
self.exp_dir,
'datasets/pallet/test_val_data_pts{}.pkl'.format(self.n_sample_points),
)
self.intrinsic_matrix
class Pallet_Dataset
self.dataset_name = dataset_name
self.xmap = np.array([[j for i in range(640)] for j in range(480)])
self.ymap = np.array([[i for i in range(640)] for j in range(480)])
self.diameters = {} 直径
self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05) 随机改变图像的亮度、对比度和饱和度。
self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.224])
self.cls_lst = bs_utils.read_lines(config.pallet_cls_lst_p)
self.obj_dict= {} 物体字典,键是物体名字,值是索引1到21
for cls_id, cls in enumerate(self.cls_lst, start=1):
self.obj_dict[cls] = cls_id
self.path 训练集数据列表文件路径
self.all_lst 训练集中的真实数据
self.real_lst = [] 训练集中的真实数据,存的东西是data/0048/000135类型
for item in self.all_lst:
self.real_lst.append(item)
self.pp_data = None
if os.path.exists(config.preprocessed_testset_pth) and config.use_preprocess:
print('Loading valtestset.')
with open(config.preprocessed_testset_pth, 'rb') as f:
self.pp_data = pkl.load(f)
self.all_lst = [i for i in range(len(self.pp_data))]
print('Finish loading valtestset.')
else:
self.add_noise = False
root_path = os.path.split(os.path.realpath(__file__))[0]
self.path = root_path + '/dataset_config/test_data_list.txt'
self.all_lst = bs_utils.read_lines(self.path)
print("{}_dataset_size: ".format(dataset_name), len(self.all_lst))
self.root = config.pallet_root
self.current_item_name = ""
def real_syn_gen(self): 随机读取一个训练集中的元素
def real_gen(self): 从real_lst中随机读一个
def rand_range(self, rng, lo, hi): 在hi和lo中随机生成一个值返回
def gaussian_noise(self, rng, img, sigma): 将给定 sigma 的高斯噪声添加到图像
def linear_motion_blur(self, img, angle, length): 线性运动模糊
def rgb_add_noise(self, img): 通过self.rng产生随机数,来决定对图像进行何种增强,锐化,运动模糊,高斯模糊
def get_normal(self, cld): 计算点云的法向量
def add_real_back(self, rgb, labels, dpt, dpt_msk):
def get_item(self, item_name): 读取了mat数据
def get_item_in_use(self, item_name): 仅在推理,无标注信息时使用!!!