!转载请注明原文地址!——东方旅行者
更多行人重识别文章移步我的专栏:行人重识别专栏
数据管理器(dataset_manager.py)
一、数据管理器作用
该文件主要负责指定数据集路径、处理原始数据集并生成数据索引列表、返回子数据集相关参数(子集行人ID数量,子集图片数量)。因为Market1501已经划分好训练集、测试集与查询集,所以直接可以根据路径提取这三个数据集。
二、数据管理器编写思路
- 指定数据集根目录路径
- 分别指定训练集、测试集、查询集的路径
- 通过这三个子数据集路径获得子集下所有图片的地址,通过每个图片的地址就可以得到行人ID与摄像机ID等信息,根据这些信息生成一个索引列表,类型为list,列表中每个元素都是一个三元组(数据图片地址,行人ID,摄像头ID),与此同时获取子数据集相关信息,如子集行人ID数量,子集图片数量等参数
- 三个数据集索引列表生成完毕,打印相关参数到控制台
索引列表如下所示:
[
('./data/Market-1501-v15.09.15\\bounding_box_train\\0002_c1s1_000451_03.jpg',0,0),
('./data/Market-1501-v15.09.15\\bounding_box_train\\0002_c1s1_000551_01.jpg',0,0),
('./data/Market-1501-v15.09.15\\bounding_box_train\\0002_c1s1_000776_01.jpg',0,0)
]
因为每一个子集中行人ID不一定连续,所以为了便于训练,一般要对训练集的行人ID进行重排,便于训练。所以需要使用一个名称为pid2label的Map来记录原始ID与重排ID的对应关系。
三、代码
import os
import os.path as osp
import numpy as np
import glob
import re
from IPython import embed
"""
Market1501类用于
1.指定数据集路径
2.处理原始数据集并生成数据索引列表
3.返回子数据集的相关参数(子集行人ID数量,子集图片数量)
"""
class Market1501(object):
dataset_dir='data/Market-1501-v15.09.15'#指定数据集路径
def __init__(self,root='./',**kwargs):
self.dataset_dir=osp.join(root,self.dataset_dir)
self.train_dir=osp.join(self.dataset_dir,'bounding_box_train')#训练集
self.gallery_dir=osp.join(self.dataset_dir,'bounding_box_test')#测试集
self.query_dir=osp.join(self.dataset_dir,'query')#查询集
train, num_train_pids, num_train_imgs=self._process_dir(self.train_dir,relabel=True)
query, num_query_pids, num_query_imgs=self._process_dir(self.query_dir,relabel=False)
gallery, num_gallery_pids, num_gallery_imgs=self._process_dir(self.gallery_dir,relabel=False)
num_total_pids=num_train_pids+num_query_pids
num_total_imgs=num_train_imgs+num_query_imgs
print("=> Market1501 loaded")
print("------------------------------------------------------------------------")
print(" subset: train \t| num_id: {:5d} \t| num_imgs:{:8d} ".format(num_train_pids,num_train_imgs))
print(" subset: query \t| num_id: {:5d} \t| num_imgs:{:8d} ".format(num_query_pids,num_query_imgs))
print(" subset: gallery \t| num_id: {:5d} \t| num_imgs:{:8d} ".format(num_gallery_pids,num_gallery_imgs))
print("------------------------------------------------------------------------")
print(" total \t\t\t| num_id: {:5d} \t| num_imgs:{:8d} ".format(num_total_pids,num_total_imgs))
print("------------------------------------------------------------------------")
self.train=train
self.query=query
self.gallery=gallery
self.num_train_pids=num_train_pids
self.num_query_pids=num_query_pids
self.num_gallery_pids=num_gallery_pids
def _process_dir(self,dir_path,relabel=False):
img_paths=glob.glob(osp.join(dir_path,'*.jpg'))
pid_container=set()
for img_path in img_paths:
pid=int(img_path.split("\\")[-1].split("_")[0])
if pid==-1:continue
pid_container.add(pid)
pid2label={pid:label for label,pid in enumerate(pid_container)}
dataset=[]
for img_path in img_paths:
str_list=img_path.split("\\")[-1].split("_")
pid=int(str_list[0])
cid=int(str_list[1][1:2])
if pid==-1:continue
assert 0<=pid <=1501
assert 1<=cid<=6
cid+=-1
if relabel:
pid=pid2label[pid]
dataset.append((img_path,pid,cid))
num_pids=len(pid_container)
num_imgs=len(img_paths)
#返回一个数据为三元组(图片地址,行人ID,摄像机ID)的索引列表形式的数据集,行人ID数量,图片数量
return dataset, num_pids, num_imgs
if __name__=='__main__':
data=Market1501()