market1501.py
# encoding: utf-8
import glob
import re
import os.path as osp
from .bases import BaseImageDataset
from collections import defaultdict
import pickle
class Market1501(BaseImageDataset):
"""
Market1501
Reference:
Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
URL: http://www.liangzheng.org/Project/project_reid.html
- 1501个身份
- 6个摄像头
- 32668张图片
- DPM检测器代替手工框出行人
- 500K张干扰图片
- 每一个身份有多个query
- 每一个query平均对应14.8个gallery
Dataset statistics:
# identities: 1501 (+1 for background)
# images: 12936 (train) + 3368 (query) + 15913 (gallery)
"""
# 数据集目录名称
dataset_dir = 'Market-1501-v15.09.15'
def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
super(Market1501, self).__init__()
# 设置数据集目录路径
self.dataset_dir = osp.join(root, self.dataset_dir)
self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
self.query_dir = osp.join(self.dataset_dir, 'query')
self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
# 检查数据集文件是否存在
self._check_before_run()
self.pid_begin = pid_begin
# 处理训练集、查询集和图库集
train = self._process_dir(self.train_dir, relabel=True)
query = self._process_dir(self.query_dir, relabel=False)
gallery = self._process_dir(self.gallery_dir, relabel=False)
if verbose:
print("=> Market1501 loaded")
self.print_dataset_statistics(train, query, gallery)
# 将处理后的数据集保存到类属性中
self.train = train
self.query = query
self.gallery = gallery
# 获取数据集的统计信息
self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)
def _check_before_run(self):
"""检查数据集文件是否存在"""
if not osp.exists(self.dataset_dir):
raise RuntimeError("'{}' is not available".format(self.dataset_dir))
if not osp.exists(self.train_dir):
raise RuntimeError("'{}' is not available".format(self.train_dir))
if not osp.exists(self.query_dir):
raise RuntimeError("'{}' is not available".format(self.query_dir))
if not osp.exists(self.gallery_dir):
raise RuntimeError("'{}' is not available".format(self.gallery_dir))
def _process_dir(self, dir_path, relabel=False):
# 获取目录下所有.jpg格式的图像文件路径
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
# 使用正则表达式解析图像文件名,提取行人ID和摄像机ID
pattern = re.compile(r'([-\d]+)_c(\d)')
pid_container = set()
# 遍历所有图像路径
for img_path in sorted(img_paths):
# 从文件名中提取行人ID和摄像机ID
pid, _ = map(int, pattern.search(img_path).groups())
if pid == -1: continue # 忽略无效的图像
pid_container.add(pid)
# 如果需要重新标注,则创建一个从行人ID到标签的映射字典
pid2label = {pid: label for label, pid in enumerate(pid_container)}
# pid_container: {2, 7, 10, 11, 12, 20, 22, 23......}
# pid2label: {2: 0, 7:1, 10:2, 11:3, 12:4, 20:5, 22:6, 23:7......}
# 创建一个列表,存储处理后的数据集
dataset = []
"""
'0002_c1s1_000451_03.jpg'
0002 是行人 ID,Market 1501 有 1501 个行人,故行人 ID 范围为 0001-1501
c1 是摄像头编号(camera 1),表明图片采集自第1个摄像头,一共有 6 个摄像头
s1 是视频的第一个片段(sequece1),一个视频包含若干个片段
000451 是视频的第 451 帧图片,表明行人出现在该帧图片中
03 代表第 451 帧图片上的第三个检测框,DPM 检测器可能在一帧图片上生成多个检测框
"""
# 再次遍历所有图像路径
for img_path in sorted(img_paths):
# 从文件名中提取行人ID和摄像机ID
pid, camid = map(int, pattern.search(img_path).groups())
if pid == -1: continue # 忽略无效的图像
assert 0 <= pid <= 1501 # 确保行人ID有效
assert 1 <= camid <= 6 # 确保摄像机ID有效
camid -= 1 # 摄像机ID从0开始
if relabel:
# 如果需要重新标注,使用新的行人ID
pid = pid2label[pid]
# 将图像路径、行人ID、摄像机ID和初始标签(0)添加到数据集中
dataset.append((img_path, self.pid_begin + pid, camid, 0))
return dataset
代码功能简介
这段代码定义了一个用于处理 Market-1501 数据集的类 Market1501
,该数据集常用于行人重识别任务。该类继承自 BaseImageDataset
并包含以下功能:
- 初始化数据集路径。
- 检查数据集文件是否存在。
- 处理图像目录,提取图像路径、行人ID、摄像机ID等信息,并根据需要重新标注行人ID。
- 打印数据集统计信息。
- 计算并存储训练集、查询集和图库集的图像数量和行人数量等信息。
其中pattern = re.compile(r'([-\d]+)_c(\d)')
是一条用来创建正则表达式模式的语句,用于从文件名中提取行人ID(pid)和摄像机ID(camid)。
详细解释
-
re.compile(r'([-\d]+)_c(\d)')
:re.compile
是 Python 的re
模块中的一个函数,用于编译正则表达式模式。r'([-\d]+)_c(\d)'
是一个原始字符串(raw string),其中的正则表达式模式用于匹配特定格式的字符串。
-
正则表达式模式
r'([-\d]+)_c(\d)'
:r
表示原始字符串,不会对字符串中的反斜杠进行转义。([-\d]+)
:这是第一个捕获组,用于匹配行人ID(pid)。[-\d]
:匹配一个字符,可以是数字(\d
)或连字符(-
)。+
:表示前面的字符类可以重复一次或多次,因此[-\d]+
可以匹配一个或多个数字或连字符的组合。
_c
:匹配字符_c
,这是固定字符,表示摄像机ID前的标识。(\d)
:这是第二个捕获组,用于匹配摄像机ID(camid)。\d
:匹配一个数字字符。
举例说明
假设有以下文件名:
0001_c1.jpg
0123_c2.jpg
-001_c3.jpg
对于每个文件名,正则表达式 ([-\d]+)_c(\d)
将进行以下匹配:
-
0001_c1.jpg
:([-\d]+)
匹配0001
,即行人ID。(\d)
匹配1
,即摄像机ID。
-
0123_c2.jpg
:([-\d]+)
匹配0123
,即行人ID。(\d)
匹配2
,即摄像机ID。
-
-001_c3.jpg
:([-\d]+)
匹配-001
,即行人ID。(\d)
匹配3
,即摄像机ID。
使用示例
下面是一个示例代码,展示如何使用这个正则表达式模式从文件名中提取行人ID和摄像机ID:
import re
# 编译正则表达式模式
pattern = re.compile(r'([-\d]+)_c(\d)')
# 示例文件名
file_names = ['0001_c1.jpg', '0123_c2.jpg', '-001_c3.jpg']
for file_name in file_names:
match = pattern.search(file_name)
if match:
pid, camid = match.groups()
print(f"File: {file_name}, Person ID: {pid}, Camera ID: {camid}")
输出结果
File: 0001_c1.jpg, Person ID: 0001, Camera ID: 1
File: 0123_c2.jpg, Person ID: 0123, Camera ID: 2
File: -001_c3.jpg, Person ID: -001, Camera ID: 3
这个正则表达式模式 ([-\d]+)_c(\d)
的作用就是从符合特定格式的字符串(如文件名)中提取出行人ID和摄像机ID,以便后续处理。
make_dataloader.py
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
from .bases import ImageDataset
from timm.data.random_erasing import RandomErasing
from .sampler import RandomIdentitySampler
from .dukemtmcreid import DukeMTMCreID
from .market1501 import Market1501
from .msmt17 import MSMT17
from .sampler_ddp import RandomIdentitySampler_DDP
import torch.distributed as dist
from .occ_duke import OCC_DukeMTMCreID
from .vehicleid import VehicleID
from .veri import VeRi
# 数据集类的工厂字典
__factory = {
'market1501': Market1501,
'dukemtmc': DukeMTMCreID,
'msmt17': MSMT17,
'occ_duke': OCC_DukeMTMCreID,
'veri': VeRi,
'VehicleID': VehicleID
}
# 将一个 batch 的数据进行整理和处理,以便于模型训练。它将从数据集中提取的元素(如图像、标签等)进行打包,并转换为适合模型输入的格式。
def train_collate_fn(batch):
"""
# collate_fn这个函数的输入就是一个list,list的长度是一个batch size,list中的每个元素都是__getitem__得到的结果
"""
# zip(*batch) 将 batch 中的每个元素按位置进行分组。假设 batch 中的每个元素是 (img, pid, camid, viewid, _),zip(*batch) 将返回五个元组
imgs, pids, camids, viewids , _ = zip(*batch)
pids = torch.tensor(pids, dtype=torch.int64)
viewids = torch.tensor(viewids, dtype=torch.int64)
camids = torch.tensor(camids, dtype=torch.int64)
# 将 pids、camids 和 viewids 转换为 PyTorch 的 int64 类型张量。这是因为模型训练中需要这些标签为张量形式,便于计算和操作。
return torch.stack(imgs, dim=0), pids, camids, viewids,
# torch.stack(imgs, dim=0) 将 imgs 列表中的所有图像张量沿着新的维度(这里是第0维)进行堆叠,形成一个 batch 的图像张量。
"""
最终返回:
torch.stack(imgs, dim=0):一个包含所有图像的张量,形状为 (batch_size, C, H, W)。
torch.tensor(pids, dtype=torch.int64):一个包含所有行人ID的张量,形状为 (batch_size,)。
torch.tensor(camids, dtype=torch.int64):一个包含所有摄像机ID的张量,形状为 (batch_size,)。
torch.tensor(viewids, dtype=torch.int64):一个包含所有视角ID的张量,形状为 (batch_size,)。
"""
def val_collate_fn(batch):
"""
该函数用于将验证集中的一个batch进行整理和处理。
"""
# 将批次中的每个元素解包
imgs, pids, camids, viewids, img_paths = zip(*batch)
# 将视角ID和摄像机ID转换为张量
viewids = torch.tensor(viewids, dtype=torch.int64)
camids_batch = torch.tensor(camids, dtype=torch.int64)
# 返回整理后的图像张量、标签张量和图像路径
return torch.stack(imgs, dim=0), pids, camids, camids_batch, viewids, img_paths
def make_dataloader(cfg):
"""
创建数据加载器,应用数据预处理和增强操作。
"""
# 定义训练集的图像预处理和数据增强操作
train_transforms = T.Compose([
T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3), # 调整图像大小
T.RandomHorizontalFlip(p=cfg.INPUT.PROB), # 随机水平翻转
T.Pad(cfg.INPUT.PADDING), # 填充
T.RandomCrop(cfg.INPUT.SIZE_TRAIN), # 随机裁剪
T.ToTensor(), # 转换为张量
T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD), # 归一化
RandomErasing(probability=cfg.INPUT.RE_PROB, mode='pixel', max_count=1, device='cpu'), # 随机擦除
])
# 定义验证集的图像预处理操作
val_transforms = T.Compose([
T.Resize(cfg.INPUT.SIZE_TEST), # 调整图像大小
T.ToTensor(), # 转换为张量
T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) # 归一化
])
num_workers = cfg.DATALOADER.NUM_WORKERS # 设置数据加载的工作线程数量
# 根据配置中的数据集名称实例化相应的数据集类
dataset = __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR)
# 创建训练集和验证集的ImageDataset实例
train_set = ImageDataset(dataset.train, train_transforms)
train_set_normal = ImageDataset(dataset.train, val_transforms)
num_classes = dataset.num_train_pids # 获取训练集中的行人类别数
cam_num = dataset.num_train_cams # 获取训练集中的摄像机数量
view_num = dataset.num_train_vids # 获取训练集中的视角数量
# 根据配置选择数据加载的采样器
if 'triplet' in cfg.DATALOADER.SAMPLER:
if cfg.MODEL.DIST_TRAIN:
# 分布式训练配置
print('DIST_TRAIN START')
mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // dist.get_world_size()
data_sampler = RandomIdentitySampler_DDP(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE)
batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)
train_loader = torch.utils.data.DataLoader(
train_set,
num_workers=num_workers,
batch_sampler=batch_sampler,
collate_fn=train_collate_fn,
pin_memory=True,
)
else:
train_loader = DataLoader(
train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
num_workers=num_workers, collate_fn=train_collate_fn
)
elif cfg.DATALOADER.SAMPLER == 'softmax':
print('using softmax sampler')
train_loader = DataLoader(
train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
collate_fn=train_collate_fn
)
else:
print('unsupported sampler! expected softmax or triplet but got {}'.format(cfg.SAMPLER))
# 创建验证集的数据加载器
val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
val_loader = DataLoader(
val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
collate_fn=val_collate_fn
)
# 创建一个标准化的训练集加载器
train_loader_normal = DataLoader(
train_set_normal, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
collate_fn=val_collate_fn
)
# 返回训练集加载器,标准化的训练集加载器,验证集加载器,查询集大小,类别数,摄像机数,视角数
return train_loader, train_loader_normal, val_loader, len(dataset.query), num_classes, cam_num, view_num
代码功能简介
这段代码的主要功能是定义数据加载和预处理流程,包括训练数据和验证数据的加载、数据增强、采样等。具体包含以下几个部分:
- 导入必要的库和模块:包括 PyTorch、torchvision.transforms、数据集类、采样器、数据增强等。
- 定义数据集类的工厂字典
__factory
:用于根据配置动态实例化不同的数据集类。 - 定义数据整理函数
train_collate_fn
和val_collate_fn
:用于将一个 batch 中的数据进行整理和处理。 - 定义函数
make_dataloader
:用于创建数据加载器,包括训练数据和验证数据的加载和处理。
主要函数和功能
train_collate_fn
函数
- 输入:一个 batch 的数据(list)。
- 功能:将 batch 中的数据解包,并转换为张量。
- 输出:图像张量、行人ID张量、摄像机ID张量和视角ID张量。
val_collate_fn
函数
- 输入:一个 batch 的数据(list)。
- 功能:将 batch 中的数据解包,并转换为张量。
- 输出:图像张量、行人ID张量、摄像机ID张量、视角ID张量和图像路径。
make_dataloader
函数
- 输入:配置对象
cfg
。 - 功能:
- 定义训练集和验证集的图像预处理和数据增强操作。
- 根据配置中的数据集名称实例化相应的数据集类。
- 创建训练集和验证集的
ImageDataset
实例。 - 根据配置选择数据加载的采样器,创建训练集数据加载器。
- 创建验证集的数据加载器。
- 返回训练集加载器、标准化的训练集加载器、验证集加载器、查询集大小、类别数、摄像机数、视角数。
其中 dataset = __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR)
是在一个配置驱动的框架中用于实例化特定数据集类的代码。以下是详细解释:
代码分解与解释
-
__factory
:__factory
通常是一个字典,用于存储不同数据集类的构造函数。- 这个字典的键是数据集的名称,值是对应的数据集类或构造函数。
-
cfg.DATASETS.NAMES
:cfg
是一个配置对象,包含了数据集和其他配置项。cfg.DATASETS.NAMES
是配置对象中指定数据集名称的属性,通常是一个字符串,表示要使用的数据集名称。- 例如,
cfg.DATASETS.NAMES
可能是"Market1501"
。
-
cfg.DATASETS.ROOT_DIR
:cfg.DATASETS.ROOT_DIR
是配置对象中指定数据集根目录的属性,通常是一个字符串,表示数据集所在的文件路径。- 例如,
cfg.DATASETS.ROOT_DIR
可能是"/path/to/dataset/root"
。
-
实例化数据集类:
__factory[cfg.DATASETS.NAMES]
根据数据集名称从__factory
字典中获取相应的数据集类或构造函数。__factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR)
调用该构造函数,并传入root=cfg.DATASETS.ROOT_DIR
参数,实例化数据集对象。- 例如,如果
cfg.DATASETS.NAMES
是"Market1501"
,__factory["Market1501"]
可能是Market1501
类,那么__factory["Market1501"](root=cfg.DATASETS.ROOT_DIR)
就相当于Market1501(root=cfg.DATASETS.ROOT_DIR)
。
示例
假设 __factory
定义如下:
__factory = {
'Market1501': Market1501,
'DukeMTMC': DukeMTMC,
# 其他数据集类
}
配置对象 cfg
的相关部分如下:
cfg = {
'DATASETS': {
'NAMES': 'Market1501',
'ROOT_DIR': '/path/to/dataset/root'
}
}
代码执行过程如下:
dataset = __factory[cfg['DATASETS']['NAMES']](root=cfg['DATASETS']['ROOT_DIR'])
实际上执行的是:
dataset = Market1501(root='/path/to/dataset/root')
这样,dataset
就是一个 Market1501
数据集类的实例,已经被初始化并准备好使用。
总结
这句代码的主要作用是根据配置文件中的数据集名称和根目录,动态地实例化相应的数据集类对象。通过这种方式,可以灵活地切换不同的数据集,而无需修改代码中的数据集类名称,符合配置驱动的设计思想。