数据加载部分–base.py
PVEN车辆重识别
论文地址:https://arxiv.org/abs/2004.05021
Pytorch代码地址:https://github.com/silverbulletmdc/PVEN
get_imagedata_info函数
输入:data(由字典组成的列表,每个字典需包含"id"和"cam"两个键)
返回:num_pids(车辆id数目), num_imgs(图像数量), num_cams(摄像头数目)
def get_imagedata_info(data):
    """Count the distinct vehicle ids, images and distinct cameras in ``data``.

    Args:
        data: Sized collection of dicts. Every dict must provide the
            ``"id"`` and ``"cam"`` keys, and both values must be hashable.

    Returns:
        A tuple ``(num_pids, num_imgs, num_cams)``: the number of distinct
        ``"id"`` values, ``len(data)``, and the number of distinct ``"cam"``
        values.

    Raises:
        KeyError: If an item has no ``"id"`` or no ``"cam"`` key.
    """
    # Set comprehensions de-duplicate directly; there is no need to build
    # intermediate lists first (the original built them twice).
    pids = {item["id"] for item in data}
    cams = {item["cam"] for item in data}
    return len(pids), len(data), len(cams)
relabel函数
输入:data(由字典组成的列表,每个字典需包含"id"和"cam"两个键)
返回:data(id与cam更新后的数据列表), rawid2label(车辆id对应的序号字典), label2rawid(序号对应的车辆id字典)
def relabel(data):
    """Replace raw vehicle ids by contiguous labels and cast cameras to int.

    The distinct raw ids are sorted and numbered ``0..N-1``; every record's
    ``"id"`` is replaced by that number and its ``"cam"`` is passed through
    ``int()``.

    Args:
        data: Iterable of dicts. Every dict must provide ``"id"`` (hashable,
            mutually sortable) and ``"cam"`` (convertible with ``int()``).

    Returns:
        A tuple ``(data, rawid2label, label2rawid)`` where ``data`` is a new
        list of relabelled records, ``rawid2label`` maps raw id -> label and
        ``label2rawid`` maps label -> raw id.

    Raises:
        KeyError: If a record has no ``"id"`` or no ``"cam"`` key.
    """
    # Copy every record, not just the outer list: a shallow list copy would
    # still share the dicts, so the assignments below would silently rewrite
    # the caller's metadata.
    relabeled = [item.copy() for item in data]

    raw_ids = sorted({item["id"] for item in relabeled})
    rawid2label = {raw_vid: label for label, raw_vid in enumerate(raw_ids)}
    label2rawid = dict(enumerate(raw_ids))

    for item in relabeled:
        item["id"] = rawid2label[item["id"]]
        item["cam"] = int(item["cam"])
    return relabeled, rawid2label, label2rawid
ReIDMetaDataset类
类的属性(以下为__init__中依次执行的主要语句):
self.train = metas["train"]#训练集标签
self.query = metas["query"]#查询集标签
self.gallery = metas["gallery"]#gallery集合标签
self.relabel()#标签更新
self._calc_meta_info()#统计各个集合信息
if verbose:#用于输出数据集的统计信息
print("=> Dataset loaded")
self.print_dataset_statistics()#输出数据集的统计信息
class ReIDMetaDataset:
    """
    Meta information of a ReID dataset; exposes ``train``, ``query`` and ``gallery``.

    Each of the three is a list of dicts. One dict describes one image:
    {
        "image_path": str, required
        "id": int, required
        "cam"(optional): int,
        "keypoints"(optional): extra information
        "kp_vis"(optional): whether each keypoint is visible
        "mask"(optional): extra information
        "box"(optional): extra information
        "color"(optional): extra information
        "type"(optional): extra information
        "view"(optional): extra information
    }

    NOTE(review): "cam" is listed as optional, but the module-level
    relabel() and get_imagedata_info() used below index item["cam"]
    unconditionally -- confirm whether it is really optional.
    """

    def __init__(self, pkl_path, verbose=True, **kwargs):
        """Load the pickled splits, relabel them and cache their statistics.

        :param pkl_path: path of a pickle file holding a mapping with the
            keys "train", "query" and "gallery"
        :param verbose: when truthy, print the dataset statistics after loading
        :param kwargs: accepted but not used by this class
        """
        # NOTE(review): unpickling can execute arbitrary code -- only load
        # files from a trusted source.
        with open(pkl_path, 'rb') as f:
            metas = pkl.load(f)

        self.train = metas["train"]
        self.query = metas["query"]
        self.gallery = metas["gallery"]

        self.relabel()
        self._calc_meta_info()

        if not verbose:
            return
        print("=> Dataset loaded")
        self.print_dataset_statistics()

    def relabel(self):
        """Relabel the training split on its own and query+gallery jointly."""
        relabeled_train = relabel(self.train)
        self.train, self.train_rawid2label, self.train_label2rawid = relabeled_train

        # Query and gallery share one label space, so they are relabelled as
        # a single concatenated list and split apart again afterwards.
        eval_set, self.eval_rawid2label, self.eval_label2rawid = relabel(self.query + self.gallery)
        num_query = len(self.query)
        self.query = eval_set[:num_query]
        self.gallery = eval_set[num_query:]

    def print_dataset_statistics(self):
        """Print a table with id / image / camera counts of every split."""
        # Gather all numbers first so nothing is printed if a split is malformed.
        train_stats = get_imagedata_info(self.train)
        query_stats = get_imagedata_info(self.query)
        gallery_stats = get_imagedata_info(self.gallery)

        print("Dataset statistics:")
        print(" ----------------------------------------")
        print(" subset | # ids | # images | # cameras")
        print(" ----------------------------------------")
        print(" train | {:5d} | {:8d} | {:9d}".format(*train_stats))
        print(" query | {:5d} | {:8d} | {:9d}".format(*query_stats))
        print(" gallery | {:5d} | {:8d} | {:9d}".format(*gallery_stats))
        print(" ----------------------------------------")

    def _calc_meta_info(self):
        """Store the id / image / camera counts of every split as attributes."""
        train_stats = get_imagedata_info(self.train)
        self.num_train_ids, self.num_train_imgs, self.num_train_cams = train_stats

        query_stats = get_imagedata_info(self.query)
        self.num_query_ids, self.num_query_imgs, self.num_query_cams = query_stats

        gallery_stats = get_imagedata_info(self.gallery)
        self.num_gallery_ids, self.num_gallery_imgs, self.num_gallery_cams = gallery_stats
ReIDDataset类
继承Dataset类,用来加载数据,需要定义__init__,__getitem__和__len__方法
class ReIDDataset(Dataset):
    """Image dataset built on top of a collection of ReID metadata dicts."""

    def __init__(self, meta_dataset, *, with_mask=False, mask_num=5, transform=None, preprocessing=None):
        """Wrap a metadata collection as an image dataset with optional preprocessing.

        Arguments:
            meta_dataset -- collection of metadata dicts; it is indexed with
                ``[]`` and measured with ``len()``. Every dict needs
                "image_path", plus "mask_path" when ``with_mask`` is True.
                NOTE(review): presumably one split of a ReIDMetaDataset
                (e.g. its ``train`` list) -- confirm against the callers.

        Keyword Arguments:
            with_mask {bool} -- whether to load the mask from "mask_path" (default: {False})
            mask_num {int} -- number of mask channels; pixel values
                0..mask_num-1 each become one channel (default: {5})
            transform {callable} -- data augmentation, called as
                ``transform(**sample)`` (default: {None})
            preprocessing {callable} -- normalize / to-tensor style step,
                called as ``preprocessing(**sample)`` (default: {None})
        """
        self.meta_dataset = meta_dataset
        self.transform = transform
        self.preprocessing = preprocessing
        self.with_mask = with_mask
        self.mask_num = mask_num

    def read_mask(self, sample):
        """Load the mask at ``sample["mask_path"]`` into ``sample["mask"]``.

        The file is read as a single-channel grayscale image; one float32
        channel is produced per value in ``range(self.mask_num)`` and the
        channels are stacked along a new last axis.

        Raises:
            FileNotFoundError: If the mask file is missing or cannot be decoded.
        """
        mask_path = sample["mask_path"]
        label_map = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the comparisons below would quietly yield an
        # all-zero "mask" of shape (mask_num,).
        if label_map is None:
            raise FileNotFoundError(f"Cannot read mask image: {mask_path}")
        channels = [label_map == v for v in range(self.mask_num)]
        sample["mask"] = np.stack(channels, axis=-1).astype('float32')

    def __getitem__(self, item):
        """Return the sample at index ``item`` with the image (and mask) loaded."""
        meta: dict = self.meta_dataset[item]
        # Work on a copy so the stored metadata never accumulates image data.
        sample = meta.copy()

        # Load the image
        sample["image"] = read_rgb_image(meta["image_path"])

        # Load the mask
        if self.with_mask:
            self.read_mask(sample)

        # Data augmentation
        if self.transform:
            sample = self.transform(**sample)

        # Preprocessing
        if self.preprocessing:
            sample = self.preprocessing(**sample)

        return sample

    def __len__(self):
        """Return the number of samples in the underlying metadata collection."""
        return len(self.meta_dataset)
这就是base.py文件中的主要内容