CLIP-ReID代码解读五——dataset文件夹(数据集读取和处理)

market1501.py

# encoding: utf-8
import glob
import re
import os.path as osp
from .bases import BaseImageDataset
from collections import defaultdict
import pickle

class Market1501(BaseImageDataset):
    """
    Market1501
    Reference:
    Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
    URL: http://www.liangzheng.org/Project/project_reid.html

	- 1501个身份
    - 6个摄像头
    - 32668张图片
    - DPM检测器代替手工框出行人
    - 500K张干扰图片
    - 每一个身份有多个query
    - 每一个query平均对应14.8个gallery

    Dataset statistics:
    # identities: 1501 (+1 for background)
    # images: 12936 (train) + 3368 (query) + 15913 (gallery)
    """
    # 数据集目录名称
    dataset_dir = 'Market-1501-v15.09.15'

    def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
        super(Market1501, self).__init__()
        # 设置数据集目录路径
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        # 检查数据集文件是否存在
        self._check_before_run()
        self.pid_begin = pid_begin
        # 处理训练集、查询集和图库集
        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> Market1501 loaded")
            self.print_dataset_statistics(train, query, gallery)

        # 将处理后的数据集保存到类属性中
        self.train = train
        self.query = query
        self.gallery = gallery

        # 获取数据集的统计信息
        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """检查数据集文件是否存在"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, relabel=False):
        # 获取目录下所有.jpg格式的图像文件路径
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        # 使用正则表达式解析图像文件名,提取行人ID和摄像机ID
        pattern = re.compile(r'([-\d]+)_c(\d)')
        
        pid_container = set()
        # 遍历所有图像路径
        for img_path in sorted(img_paths):
            # 从文件名中提取行人ID和摄像机ID
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # 忽略无效的图像
            pid_container.add(pid)
        
        # 如果需要重新标注,则创建一个从行人ID到标签的映射字典
        pid2label = {pid: label for label, pid in enumerate(pid_container)}
        # pid_container: {2, 7, 10, 11, 12, 20, 22, 23......}
        # pid2label: {2: 0, 7:1, 10:2, 11:3, 12:4, 20:5, 22:6, 23:7......}
        
        # 创建一个列表,存储处理后的数据集
        dataset = []

		"""
        '0002_c1s1_000451_03.jpg'
        0002 是行人 ID,Market 1501 有 1501 个行人,故行人 ID 范围为 0001-1501
        c1 是摄像头编号(camera 1),表明图片采集自第1个摄像头,一共有 6 个摄像头
        s1 是视频的第一个片段(sequece1),一个视频包含若干个片段
        000451 是视频的第 451 帧图片,表明行人出现在该帧图片中
        03 代表第 451 帧图片上的第三个检测框,DPM 检测器可能在一帧图片上生成多个检测框
        """

        # 再次遍历所有图像路径
        for img_path in sorted(img_paths):
            # 从文件名中提取行人ID和摄像机ID
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # 忽略无效的图像
            assert 0 <= pid <= 1501  # 确保行人ID有效
            assert 1 <= camid <= 6  # 确保摄像机ID有效
            camid -= 1  # 摄像机ID从0开始
            if relabel: 
                # 如果需要重新标注,使用新的行人ID
                pid = pid2label[pid]

            # 将图像路径、行人ID、摄像机ID和初始标签(0)添加到数据集中
            dataset.append((img_path, self.pid_begin + pid, camid, 0))
        return dataset

代码功能简介

这段代码定义了一个用于处理 Market-1501 数据集的类 Market1501,该数据集常用于行人重识别任务。该类继承自 BaseImageDataset 并包含以下功能:

  1. 初始化数据集路径。
  2. 检查数据集文件是否存在。
  3. 处理图像目录,提取图像路径、行人ID、摄像机ID等信息,并根据需要重新标注行人ID。
  4. 打印数据集统计信息。
  5. 计算并存储训练集、查询集和图库集的图像数量和行人数量等信息。

其中pattern = re.compile(r'([-\d]+)_c(\d)') 是一条用来创建正则表达式模式的语句,用于从文件名中提取行人ID(pid)和摄像机ID(camid)。

详细解释

  1. re.compile(r'([-\d]+)_c(\d)')

    • re.compile 是 Python 的 re 模块中的一个函数,用于编译正则表达式模式。
    • r'([-\d]+)_c(\d)' 是一个原始字符串(raw string),其中的正则表达式模式用于匹配特定格式的字符串。
  2. 正则表达式模式 r'([-\d]+)_c(\d)'

    • r 表示原始字符串,不会对字符串中的反斜杠进行转义。
    • ([-\d]+):这是第一个捕获组,用于匹配行人ID(pid)。
      • [-\d]:匹配一个字符,可以是数字(\d)或连字符(-)。
      • +:表示前面的字符类可以重复一次或多次,因此 [-\d]+ 可以匹配一个或多个数字或连字符的组合。
    • _c:匹配字符 _c,这是固定字符,表示摄像机ID前的标识。
    • (\d):这是第二个捕获组,用于匹配摄像机ID(camid)。
      • \d:匹配一个数字字符。

举例说明

假设有以下文件名:

  • 0001_c1.jpg
  • 0123_c2.jpg
  • -001_c3.jpg

对于每个文件名,正则表达式 ([-\d]+)_c(\d) 将进行以下匹配:

  • 0001_c1.jpg

    • ([-\d]+) 匹配 0001,即行人ID。
    • (\d) 匹配 1,即摄像机ID。
  • 0123_c2.jpg

    • ([-\d]+) 匹配 0123,即行人ID。
    • (\d) 匹配 2,即摄像机ID。
  • -001_c3.jpg

    • ([-\d]+) 匹配 -001,即行人ID。
    • (\d) 匹配 3,即摄像机ID。

使用示例

下面是一个示例代码,展示如何使用这个正则表达式模式从文件名中提取行人ID和摄像机ID:

import re

# 编译正则表达式模式
pattern = re.compile(r'([-\d]+)_c(\d)')

# 示例文件名
file_names = ['0001_c1.jpg', '0123_c2.jpg', '-001_c3.jpg']

for file_name in file_names:
    match = pattern.search(file_name)
    if match:
        pid, camid = match.groups()
        print(f"File: {file_name}, Person ID: {pid}, Camera ID: {camid}")

输出结果

File: 0001_c1.jpg, Person ID: 0001, Camera ID: 1
File: 0123_c2.jpg, Person ID: 0123, Camera ID: 2
File: -001_c3.jpg, Person ID: -001, Camera ID: 3

这个正则表达式模式 ([-\d]+)_c(\d) 的作用就是从符合特定格式的字符串(如文件名)中提取出行人ID和摄像机ID,以便后续处理。

make_dataloader.py

import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader

from .bases import ImageDataset
from timm.data.random_erasing import RandomErasing
from .sampler import RandomIdentitySampler
from .dukemtmcreid import DukeMTMCreID
from .market1501 import Market1501
from .msmt17 import MSMT17
from .sampler_ddp import RandomIdentitySampler_DDP
import torch.distributed as dist
from .occ_duke import OCC_DukeMTMCreID
from .vehicleid import VehicleID
from .veri import VeRi

# 数据集类的工厂字典
__factory = {
    'market1501': Market1501,
    'dukemtmc': DukeMTMCreID,
    'msmt17': MSMT17,
    'occ_duke': OCC_DukeMTMCreID,
    'veri': VeRi,
    'VehicleID': VehicleID
}


# 将一个 batch 的数据进行整理和处理,以便于模型训练。它将从数据集中提取的元素(如图像、标签等)进行打包,并转换为适合模型输入的格式。
def train_collate_fn(batch):
    """
    # collate_fn这个函数的输入就是一个list,list的长度是一个batch size,list中的每个元素都是__getitem__得到的结果
    """
    # zip(*batch) 将 batch 中的每个元素按位置进行分组。假设 batch 中的每个元素是 (img, pid, camid, viewid, _),zip(*batch) 将返回五个元组
    imgs, pids, camids, viewids , _ = zip(*batch)
    pids = torch.tensor(pids, dtype=torch.int64)
    viewids = torch.tensor(viewids, dtype=torch.int64)
    camids = torch.tensor(camids, dtype=torch.int64)
    # 将 pids、camids 和 viewids 转换为 PyTorch 的 int64 类型张量。这是因为模型训练中需要这些标签为张量形式,便于计算和操作。
    return torch.stack(imgs, dim=0), pids, camids, viewids,
    # torch.stack(imgs, dim=0) 将 imgs 列表中的所有图像张量沿着新的维度(这里是第0维)进行堆叠,形成一个 batch 的图像张量。
    """
    最终返回:
    torch.stack(imgs, dim=0):一个包含所有图像的张量,形状为 (batch_size, C, H, W)。
    torch.tensor(pids, dtype=torch.int64):一个包含所有行人ID的张量,形状为 (batch_size,)。
    torch.tensor(camids, dtype=torch.int64):一个包含所有摄像机ID的张量,形状为 (batch_size,)。
    torch.tensor(viewids, dtype=torch.int64):一个包含所有视角ID的张量,形状为 (batch_size,)。
    """

def val_collate_fn(batch):
    """
    该函数用于将验证集中的一个batch进行整理和处理。
    """
    # 将批次中的每个元素解包
    imgs, pids, camids, viewids, img_paths = zip(*batch)
    # 将视角ID和摄像机ID转换为张量
    viewids = torch.tensor(viewids, dtype=torch.int64)
    camids_batch = torch.tensor(camids, dtype=torch.int64)
    # 返回整理后的图像张量、标签张量和图像路径
    return torch.stack(imgs, dim=0), pids, camids, camids_batch, viewids, img_paths

def make_dataloader(cfg):
    """
    创建数据加载器,应用数据预处理和增强操作。
    """
    # 定义训练集的图像预处理和数据增强操作
    train_transforms = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3),  # 调整图像大小
        T.RandomHorizontalFlip(p=cfg.INPUT.PROB),  # 随机水平翻转
        T.Pad(cfg.INPUT.PADDING),  # 填充
        T.RandomCrop(cfg.INPUT.SIZE_TRAIN),  # 随机裁剪
        T.ToTensor(),  # 转换为张量
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),  # 归一化
        RandomErasing(probability=cfg.INPUT.RE_PROB, mode='pixel', max_count=1, device='cpu'),  # 随机擦除
    ])

    # 定义验证集的图像预处理操作
    val_transforms = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TEST),  # 调整图像大小
        T.ToTensor(),  # 转换为张量
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)  # 归一化
    ])

    num_workers = cfg.DATALOADER.NUM_WORKERS  # 设置数据加载的工作线程数量

    # 根据配置中的数据集名称实例化相应的数据集类
    dataset = __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR)
    
    # 创建训练集和验证集的ImageDataset实例
    train_set = ImageDataset(dataset.train, train_transforms)
    train_set_normal = ImageDataset(dataset.train, val_transforms)
    num_classes = dataset.num_train_pids  # 获取训练集中的行人类别数
    cam_num = dataset.num_train_cams  # 获取训练集中的摄像机数量
    view_num = dataset.num_train_vids  # 获取训练集中的视角数量

    # 根据配置选择数据加载的采样器
    if 'triplet' in cfg.DATALOADER.SAMPLER:
        if cfg.MODEL.DIST_TRAIN:
            # 分布式训练配置
            print('DIST_TRAIN START')
            mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // dist.get_world_size()
            data_sampler = RandomIdentitySampler_DDP(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE)
            batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)
            train_loader = torch.utils.data.DataLoader(
                train_set,
                num_workers=num_workers,
                batch_sampler=batch_sampler,
                collate_fn=train_collate_fn,
                pin_memory=True,
            )
        else:
            train_loader = DataLoader(
                train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
                sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
                num_workers=num_workers, collate_fn=train_collate_fn
            )
    elif cfg.DATALOADER.SAMPLER == 'softmax':
        print('using softmax sampler')
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
            collate_fn=train_collate_fn
        )
    else:
        print('unsupported sampler! expected softmax or triplet but got {}'.format(cfg.SAMPLER))

    # 创建验证集的数据加载器
    val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)

    val_loader = DataLoader(
        val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
        collate_fn=val_collate_fn
    )
    # 创建一个标准化的训练集加载器
    train_loader_normal = DataLoader(
        train_set_normal, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
        collate_fn=val_collate_fn
    )
    # 返回训练集加载器,标准化的训练集加载器,验证集加载器,查询集大小,类别数,摄像机数,视角数
    return train_loader, train_loader_normal, val_loader, len(dataset.query), num_classes, cam_num, view_num

代码功能简介

这段代码的主要功能是定义数据加载和预处理流程,包括训练数据和验证数据的加载、数据增强、采样等。具体包含以下几个部分:

  1. 导入必要的库和模块:包括 PyTorch、torchvision.transforms、数据集类、采样器、数据增强等。
  2. 定义数据集类的工厂字典 __factory:用于根据配置动态实例化不同的数据集类。
  3. 定义数据整理函数 train_collate_fnval_collate_fn:用于将一个 batch 中的数据进行整理和处理。
  4. 定义函数 make_dataloader:用于创建数据加载器,包括训练数据和验证数据的加载和处理。

主要函数和功能

train_collate_fn 函数
  • 输入:一个 batch 的数据(list)。
  • 功能:将 batch 中的数据解包,并转换为张量。
  • 输出:图像张量、行人ID张量、摄像机ID张量和视角ID张量。
val_collate_fn 函数
  • 输入:一个 batch 的数据(list)。
  • 功能:将 batch 中的数据解包,并转换为张量。
  • 输出:图像张量、行人ID张量、摄像机ID张量、视角ID张量和图像路径。
make_dataloader 函数
  • 输入:配置对象 cfg
  • 功能:
    1. 定义训练集和验证集的图像预处理和数据增强操作。
    2. 根据配置中的数据集名称实例化相应的数据集类。
    3. 创建训练集和验证集的 ImageDataset 实例。
    4. 根据配置选择数据加载的采样器,创建训练集数据加载器。
    5. 创建验证集的数据加载器。
    6. 返回训练集加载器、标准化的训练集加载器、验证集加载器、查询集大小、类别数、摄像机数、视角数。

其中 dataset = __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR) 是在一个配置驱动的框架中用于实例化特定数据集类的代码。以下是详细解释:

代码分解与解释

  1. __factory:

    • __factory 通常是一个字典,用于存储不同数据集类的构造函数。
    • 这个字典的键是数据集的名称,值是对应的数据集类或构造函数。
  2. cfg.DATASETS.NAMES:

    • cfg 是一个配置对象,包含了数据集和其他配置项。
    • cfg.DATASETS.NAMES 是配置对象中指定数据集名称的属性,通常是一个字符串,表示要使用的数据集名称。
    • 例如,cfg.DATASETS.NAMES 可能是 "Market1501"
  3. cfg.DATASETS.ROOT_DIR:

    • cfg.DATASETS.ROOT_DIR 是配置对象中指定数据集根目录的属性,通常是一个字符串,表示数据集所在的文件路径。
    • 例如,cfg.DATASETS.ROOT_DIR 可能是 "/path/to/dataset/root"
  4. 实例化数据集类:

    • __factory[cfg.DATASETS.NAMES] 根据数据集名称从 __factory 字典中获取相应的数据集类或构造函数。
    • __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR) 调用该构造函数,并传入 root=cfg.DATASETS.ROOT_DIR 参数,实例化数据集对象。
    • 例如,如果 cfg.DATASETS.NAMES"Market1501"__factory["Market1501"] 可能是 Market1501 类,那么 __factory["Market1501"](root=cfg.DATASETS.ROOT_DIR) 就相当于 Market1501(root=cfg.DATASETS.ROOT_DIR)

示例

假设 __factory 定义如下:

__factory = {
    'Market1501': Market1501,
    'DukeMTMC': DukeMTMC,
    # 其他数据集类
}

配置对象 cfg 的相关部分如下:

cfg = {
    'DATASETS': {
        'NAMES': 'Market1501',
        'ROOT_DIR': '/path/to/dataset/root'
    }
}

代码执行过程如下:

dataset = __factory[cfg['DATASETS']['NAMES']](root=cfg['DATASETS']['ROOT_DIR'])

实际上执行的是:

dataset = Market1501(root='/path/to/dataset/root')

这样,dataset 就是一个 Market1501 数据集类的实例,已经被初始化并准备好使用。

总结

这句代码的主要作用是根据配置文件中的数据集名称和根目录,动态地实例化相应的数据集类对象。通过这种方式,可以灵活地切换不同的数据集,而无需修改代码中的数据集类名称,符合配置驱动的设计思想。

  • 14
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yiruzhao

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值