data/__init__.py
The make_basic_dataset function
Builds the basic dataset.
Inputs: a series of parameters (pkl_path, train_size, val_size, pad, *, test_ext='', re_prob=0.5, with_mask=False, for_vis=False)
Returns: the preprocessed training set, validation set, and meta dataset
def make_basic_dataset(pkl_path, train_size, val_size, pad, *, test_ext='', re_prob=0.5, with_mask=False, for_vis=False):
    """Build the basic dataset."""
    # Meta dataset: parses the pickle file into train / query / gallery splits.
    meta_dataset = datasets.CommonReIDDataset(pkl_path=pkl_path, test_ext=test_ext)
    # Training-time augmentation pipeline (re_prob controls random erasing);
    # validation uses a deterministic resize-only pipeline.
    train_transform = demo_trans.get_training_albumentations(train_size, pad, re_prob)
    val_transform = demo_trans.get_validation_augmentations(val_size)
    # For visualization, skip preprocessing so images remain viewable.
    if for_vis:
        preprocessing = None
    else:
        preprocessing = demo_trans.get_preprocessing()
    train_dataset = datasets.ReIDDataset(
        meta_dataset.train, with_mask=with_mask, transform=train_transform,
        preprocessing=preprocessing)
    # Validation evaluates retrieval, so query and gallery images are combined.
    val_dataset = datasets.ReIDDataset(
        meta_dataset.query + meta_dataset.gallery, with_mask=with_mask,
        transform=val_transform, preprocessing=preprocessing)
    return train_dataset, val_dataset, meta_dataset
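To see why the validation set is built as `meta_dataset.query + meta_dataset.gallery`, here is a minimal stand-alone sketch of the three splits a ReID meta dataset exposes. The class `FakeMetaDataset` and its sample tuples are hypothetical illustrations, not part of the project; the real `CommonReIDDataset` parses a pickle file.

```python
from dataclasses import dataclass, field

# Hypothetical stand-in for datasets.CommonReIDDataset, showing only the
# three attributes make_basic_dataset relies on.
@dataclass
class FakeMetaDataset:
    train: list = field(default_factory=list)    # samples used for training
    query: list = field(default_factory=list)    # probe images for evaluation
    gallery: list = field(default_factory=list)  # search pool for evaluation

meta = FakeMetaDataset(
    train=[("t1.jpg", 0), ("t2.jpg", 1)],
    query=[("q1.jpg", 0)],
    gallery=[("g1.jpg", 0), ("g2.jpg", 1)],
)

# The validation set is the concatenation of query and gallery, exactly as in
# make_basic_dataset; at evaluation time each query is matched against the gallery.
val_samples = meta.query + meta.gallery
print(len(meta.train), len(val_samples))  # 2 3
```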
Now for the key part: the call in the main function. All the work above exists solely to be invoked here. As the saying goes, one minute on stage takes ten years of practice off stage: a few simple lines of data-loading code rest on many lines of supporting code.
parsing_reid/main.py
Load the training set, validation set, and meta dataset:
train_dataset, valid_dataset, meta_dataset = make_basic_dataset(
    cfg.data.pkl_path,
    cfg.data.train_size,
    cfg.data.valid_size,
    cfg.data.pad,
    test_ext=cfg.data.test_ext,
    re_prob=cfg.data.re_prob,
    with_mask=cfg.data.with_mask,
)
Wrap the loaded training and validation sets in a DataLoader, which packages the dataset into batches so it can be read and trained batch by batch. If PyTorch's DataLoader is unfamiliar, this post is a useful reference: https://blog.csdn.net/sinat_42239797/article/details/90641659
# `sampler` is defined elsewhere in main.py; it controls how training indices are drawn.
train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=cfg.data.batch_size,
                          num_workers=cfg.data.train_num_workers, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=cfg.data.batch_size,
                          num_workers=cfg.data.test_num_workers,
                          pin_memory=True, shuffle=False)
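To make the batching behavior concrete, here is a framework-free sketch (plain Python, no torch) of what a DataLoader does at its core: draw `batch_size` items from a dataset, optionally in shuffled order, and yield them batch by batch. The class `MiniLoader` and the toy sample list are illustrative assumptions, not part of the project or of PyTorch.

```python
import random

class MiniLoader:
    """Toy stand-in for torch.utils.data.DataLoader: yields fixed-size batches."""

    def __init__(self, dataset, batch_size, shuffle=False, seed=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        if self.shuffle:
            self.rng.shuffle(indices)
        # Yield successive slices of batch_size indices; the last batch may be smaller.
        for start in range(0, len(indices), self.batch_size):
            yield [self.dataset[i] for i in indices[start:start + self.batch_size]]

# Toy "dataset": (image_path, person_id) pairs, as a ReID dataset might return.
samples = [(f"img_{i}.jpg", i % 3) for i in range(7)]
loader = MiniLoader(samples, batch_size=3)
batches = list(loader)
print([len(b) for b in batches])  # [3, 3, 1]
```

The real DataLoader adds multi-process loading (`num_workers`), pinned host memory (`pin_memory`), and pluggable samplers on top of this core loop, which is why the training loader above passes a `sampler` while the validation loader simply sets `shuffle=False`.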
That wraps up the data-loading part. Thanks for reading!