源码URL:
https://github.com/michuanhaohao/deep-person-reid/blob/master/data_manager.py
本文内容:数据集处理与数据路径处理,以下是对该文件前100行源码的逐行注释。
from __future__ import print_function, absolute_import # 如果你的python版本是python2.X,你也得按照python3.X那样使用这些函数
import os
import os.path as osp
import numpy as np
import glob
import re # 正则表达式
from utils import mkdir_if_missing, write_json, read_json
from IPython import embed
class Market1501(object):
    """
    Market1501 person re-identification dataset index.

    Reference:
    Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.

    URL: http://www.liangzheng.org/Project/project_reid.html

    Dataset statistics:
    # identities: 1501 (+1 for background)
    # images: 12936 (train) + 3368 (query) + 15913 (gallery)

    After construction the instance exposes:
        train, query, gallery: lists of ``(img_path, pid, camid)`` tuples.
        num_train_pids, num_query_pids, num_gallery_pids: number of distinct
            person ids found in each subset.
    """

    # Location of the dataset, joined onto ``root`` in ``__init__``; change it
    # to your own data path. NOTE: ``os.path.join`` discards ``root`` when this
    # value is an absolute path on the running platform.
    dataset_dir = 'F:/Market-1501/Market-1501-v15.09.15'

    def __init__(self, root='data', **kwargs):
        """Locate the three subsets under ``root``, index them and print statistics.

        Args:
            root: directory that ``dataset_dir`` is joined onto.
            **kwargs: accepted for interface compatibility; not used.

        Raises:
            RuntimeError: if the dataset directory or any subset directory
                does not exist.
        """
        self.dataset_dir = osp.join(root, self.dataset_dir)
        # Training images live in 'bounding_box_train'; 'bounding_box_test' is
        # the gallery and must not be reused for training.
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        # Fail early with a clear message if any path is missing.
        self._check_before_run()

        # For every subset we need: file paths, labels (person id, camera id)
        # and counts. Only the training ids are relabelled to 0..N-1.
        train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
        query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False)
        gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False)

        # Total ids = train ids + query ids; gallery ids are not added here
        # (presumably because query identities also appear in the gallery).
        num_total_pids = num_train_pids + num_query_pids
        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs

        print("=> Market1501 loaded")
        print("Dataset statistics:")
        print(" ------------------------------")
        print(" subset | # ids | # images")
        print(" ------------------------------")
        print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
        print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
        print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
        print(" ------------------------------")
        print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
        print(" ------------------------------")

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids = num_train_pids
        self.num_query_pids = num_query_pids
        self.num_gallery_pids = num_gallery_pids

    def _check_before_run(self):
        """Raise ``RuntimeError`` if the dataset root or any subset directory is missing."""
        required_dirs = (
            self.dataset_dir,
            self.train_dir,
            self.query_dir,
            self.gallery_dir,
        )
        for required in required_dirs:
            if not osp.exists(required):
                raise RuntimeError("'{}' is not available".format(required))

    def _process_dir(self, dir_path, relabel=False):
        """Index every ``*.jpg`` file in ``dir_path``.

        File names are expected to contain ``<pid>_c<camid>`` (for example
        ``0002_c1s1_000451_03.jpg``). Images whose pid is -1 are junk
        (background only) and are skipped.

        Args:
            dir_path: directory containing the images.
            relabel: if True, map the original person ids onto consecutive
                labels ``0..num_pids-1`` (in ascending order of original id).

        Returns:
            ``(dataset, num_pids, num_imgs)`` where ``dataset`` is a list of
            ``(img_path, pid, camid)`` tuples with ``camid`` shifted to be
            0-based, ``num_pids`` is the number of distinct person ids and
            ``num_imgs`` is the number of images kept in ``dataset``.
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        # Parse each file name once. Matching on the basename keeps directory
        # names that happen to look like "<digits>_c<digit>" from being picked up.
        parsed = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(osp.basename(img_path)).groups())
            if pid == -1:
                continue  # junk image: background only
            parsed.append((img_path, pid, camid))

        # Distinct person ids (e.g. 751 for the training split).
        pid_container = set(pid for _, pid, _ in parsed)
        # Sort so that the original-id -> new-label mapping is deterministic.
        pid2label = {pid: label for label, pid in enumerate(sorted(pid_container))}

        dataset = []
        for img_path, pid, camid in parsed:
            assert 0 <= pid <= 1501  # original ids lie in 0..1501
            assert 1 <= camid <= 6  # cameras are numbered 1..6 in file names
            camid -= 1  # make camera ids 0-based
            if relabel:
                pid = pid2label[pid]
            dataset.append((img_path, pid, camid))

        num_pids = len(pid_container)
        # Count only the images actually kept, so the statistics match `dataset`.
        num_imgs = len(dataset)
        return dataset, num_pids, num_imgs
if __name__ == '__main__':
    # Build the dataset index; the constructor prints the dataset statistics.
    data = Market1501(root='F:/Market-1501/Market-1501-v15.09.15')
运行结果如下: