行人重识别02-06:fast-reid(BoT)-pytorch编程规范(fast-reid为例)3-迭代器构建,数据加载-1

以下链接是个人关于fast-reid(BoT行人重识别) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。
行人重识别02-00:fast-reid(BoT)-目录-史上最新无死角讲解

极度推荐的商业级项目: \color{red}{极度推荐的商业级项目:} 极度推荐的商业级项目:这是本人落地的行为分析项目,主要包含(1.行人检测,2.行人追踪,3.行为识别三大模块):行为分析(商用级别)00-目录-史上最新无死角讲解

前言

我们在 fastreid/engine/defaults.py 文件中可以看到类 class DefaultTrainer(SimpleTrainer),然后找到如下代码:

    def __init__(self, cfg):
    	......
    	# 创建训练数据及迭代器
        data_loader = self.build_train_loader(cfg)

build_train_loader的实现如下:

    @classmethod
    def build_train_loader(cls, cfg):
        """
        构建一个训练数据迭代器
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_detection_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        logger = logging.getLogger(__name__)
        logger.info("Prepare training set")
        return build_reid_train_loader(cfg)

其上的 build_reid_train_loader(cfg),就是该小节重点讲解的内容。本人注释如下(粗略注释,详细的在后面):

def build_reid_train_loader(cfg):
    # 构建数据迭代器
    cfg = cfg.clone()
    # cfg配置解冻
    cfg.defrost()

    # 训练数目列表
    train_items = list()
    # 循环加载多个数据集
    for d in cfg.DATASETS.NAMES:
        # 根据数据集名称,创建对应数据集迭代的类,本人调试为类 fastreid.data.datasets.market1501.Market1501对象
        dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
        # 如果为主线程,则显示训练信息
        if comm.is_main_process():
            dataset.show_train()
        # dataset.train包含的都是训练数据的信息,示例部分数据如下:
        # 图片路径                                                                      身份ID          摄像头编号
        #'datasets/Market-1501-v15.09.15/bounding_box_train/0309_c3s2_037562_02.jpg', 'market1501_309', 2
        # 'datasets/Market-1501-v15.09.15/bounding_box_train/0208_c1s1_045426_08.jpg', 'market1501_208', 0
        # 要知道具体的过程,需要查看类fastreid.data.datasets.market1501.Market1501
        train_items.extend(dataset.train)

    # 获得每个epoch需要迭代的次数
    iters_per_epoch = len(train_items) // cfg.SOLVER.IMS_PER_BATCH

    # 更改数据总共需要迭代的次数
    cfg.SOLVER.MAX_ITER *= iters_per_epoch
    # 构建train_transforms,其中包括了数据预处理,数据增强,剪切等等
    train_transforms = build_transforms(cfg, is_train=True)
    # CommDataset继承于Dataset,并且其中实现了__getitem__函数,
    # train_items为一个列表,还包含了所有训练数据的信息
    train_set = CommDataset(train_items, train_transforms, relabel=True)
    # 获得线程数目
    num_workers = cfg.DATALOADER.NUM_WORKERS
    # 获得实例数目(每个身份采集多少张图片,论文中的K)
    num_instance = cfg.DATALOADER.NUM_INSTANCE
    # 获得每个GPU的batch_size
    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()


    if cfg.DATALOADER.PK_SAMPLER:
        # 如果使用简单的采样方式
        if cfg.DATALOADER.NAIVE_WAY:
            data_sampler = samplers.NaiveIdentitySampler(train_set.img_items,
                                                        cfg.SOLVER.IMS_PER_BATCH, num_instance)
        # 如果使用平衡的采样方式
        else:
            data_sampler = samplers.BalancedIdentitySampler(train_set.img_items,
                                                            cfg.SOLVER.IMS_PER_BATCH, num_instance)
    else:
        data_sampler = samplers.TrainingSampler(len(train_set))
    # 构建batch数据迭代采样器,指定加载数据的线程数目等
    batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
    )
    return train_loader

其上的代码中,有如下几个地方需要重点讲解:

    for d in cfg.DATASETS.NAMES:
        # 根据数据集名称,创建对应数据集迭代的类,本人调试为类 fastreid.data.datasets.market1501.Market1501对象
        dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
        train_items.extend(dataset.train)
    train_set = CommDataset(train_items, train_transforms, relabel=True)

data_sampler = samplers.NaiveIdentitySampler(train_set.img_items,
                                                        cfg.SOLVER.IMS_PER_BATCH, num_instance)

CommDataset

其上的DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL),就是
fastreid/data/datasets/market1501.py文件中类class Market1501(ImageDataset)创建的对象,我们下个小节进行讲解。只要知道其中有个重要的属性 d a t a s e t . t r a i n \color{red}{dataset.train} dataset.train(一个列表) 包含了一个数据集所有训练的信息,然后多个数据的信息共同被添加到 train_items 之中。我们这篇博客先来看看:

train_set = CommDataset(train_items, train_transforms, relabel=True)

本人注释如下:

class CommDataset(Dataset):
    """Image Person ReID Dataset"""

    def __init__(self, img_items, transform=None, relabel=True):
        self.img_items = img_items
        self.transform = transform
        self.relabel = relabel
        self.pid_dict = {}
        # 如果重新刷新标签
        if self.relabel:
            pids = list()
            # 获得每张图像的信息
            for i, item in enumerate(img_items):
                # 如果id已经出现在pids,表示重复,则跳过
                if item[1] in pids: continue
                # 否则添加到pids之中
                pids.append(item[1])
            # 获得所有的pids
            self.pids = pids
            # 为每个ID分配一个序列号
            self.pid_dict = dict([(p, i) for i, p in enumerate(self.pids)])
            pass

    def __len__(self):
        return len(self.img_items)

    def __getitem__(self, index):
        # 根据index,获得其图像的img_path, pid, camid信息
        img_path, pid, camid = self.img_items[index]
        # 根据图像路径读取图像像素
        img = read_image(img_path)
        # 如果self.transform不为none,则进行transform处理
        if self.transform is not None: img = self.transform(img)
        # 如果刷新了标签,则标签ID更改为self.pid_dict[pid]
        if self.relabel: pid = self.pid_dict[pid]

        # 返回数据,用于进行训练或者测试
        return {
            "images": img,
            "targets": pid,
            "camid": camid,
            "img_path": img_path
        }

    @property
    def num_classes(self):
        return len(self.pids)

可以看到其实现过程还是很简单的,主要对融合的数据进行一个ID的更新,并且实现了__getitem__函数。

NaiveIdentitySampler

在def build_reid_train_loader(cfg):函数中,可以看到如下代码:

data_sampler = samplers.NaiveIdentitySampler(train_set.img_items,cfg.SOLVER.IMS_PER_BATCH, num_instance)
data_sampler = samplers.BalancedIdentitySampler(train_set.img_items,cfg.SOLVER.IMS_PER_BATCH, num_instance)
data_sampler = samplers.TrainingSampler(len(train_set))

他们实现的过程都相差不大,这里就以NaiveIdentitySampler为例进行讲解:

class NaiveIdentitySampler(Sampler):
    """
    # 首先随机采集N个ID,然后每个ID选择K个实例图像
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.
    Args:
    # 训练数据的列表,包含了所有训练的数据,也就是多个数据源
    - data_source (list): list of (img_path, pid, camid).
    # 在每个batch中,对每个ID采集num_instances图像
    - num_instances (int): number of instances per identity in a batch.
    - batch_size (int): number of examples in a batch.
    """

    def __init__(self, data_source: str, batch_size: int, num_instances: int, seed: Optional[int] = None):
        # 包含了多个数据集的训练信息,如图片路径,身份ID,摄像头编号等等
        self.data_source = data_source
        self.batch_size = batch_size # 论文中的B
        # 对每个身份采集的图像数目(论文中的K)
        self.num_instances = num_instances
        # 通过计算获得每个batch需要采集多少个身份ID(论文中的P)
        self.num_pids_per_batch = batch_size // self.num_instances

        # 用于存储该图像 序列号-身份ID 保存于字典,方便查找转换
        self.index_pid = defaultdict(list)
        # 用于存储该图像 身份ID-摄像头编号 保存于字典,方便查找转换
        self.pid_cam = defaultdict(list)
        # 用于存储该图像 身份ID-对应图片所有的序列号 保存于字典,方便查找转换
        self.pid_index = defaultdict(list)

        # 循环把数据保存于上述的三个字典之中
        for index, info in enumerate(data_source):
            pid = info[1]
            camid = info[2]
            self.index_pid[index] = pid
            self.pid_cam[pid].append(camid)
            self.pid_index[pid].append(index)

        # 把pid_index的键值(身份ID)保存于self.pids之中
        self.pids = list(self.pid_index.keys())
        # 计算共多少个身份ID
        self.num_identities = len(self.pids)

        # 设置随机种子
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        # 获得_rank,_world_size(主机数目)用于分布式训练
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def __iter__(self):
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        """
        每次迭代,根据配置生成一个batch_size=B(论文中的)*(论文中的)个indx
        """
        # 设定随机种子
        np.random.seed(self._seed)
        while True:
            # 获得有效的身份ID,这里的self.pids已经进行过更新,包含多个数据集的ID
            avai_pids = copy.deepcopy(self.pids)
            # 保存一个batch的身份ID的索引
            batch_idxs_dict = {}
            # 保存一个batch的indx,该indx主要传给CommDataset的__getitem__函数
            batch_indices = []
            # 如果有效的ID数目大于self.num_pids_per_batch
            while len(avai_pids) >= self.num_pids_per_batch:
                # 随机从avai_pids中选择self.num_pids_per_batch个身份id
                selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False)
                # 循环对每个身份ID进行处理
                for pid in selected_pids:
                    # Register pid in batch_idxs_dict if not,
                    # 如果pid这个ID在当前batch没有被采样过
                    if pid not in batch_idxs_dict:
                        # 获得pid这个身份ID对应所有图片的序列号
                        idxs = copy.deepcopy(self.pid_index[pid])
                        # 如果该身份ID图像的总数低于self.num_instances(论文中的K),则使用重复采样的方式
                        if len(idxs) < self.num_instances:
                            idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist()
                        # 随机进行采样
                        np.random.shuffle(idxs)
                        # 把采集到的pid身份对应的图像序列号添加到batch_idxs_dict中
                        batch_idxs_dict[pid] = idxs
                    # 如果该身份已经被采集过了(也就是selected_pids存在两个相同的ID),获得该ID对应所有图像的序列号
                    avai_idxs = batch_idxs_dict[pid]
                    # 重新导出num_instances个图像indx(序列号)信息,覆盖之前选到该ID对应的indx
                    for _ in range(self.num_instances):
                        batch_indices.append(avai_idxs.pop(0))
                    # 如果len(avai_idxs)小于self.num_instances则移除该ID,以及对应的图片
                    if len(avai_idxs) < self.num_instances: avai_pids.remove(pid)
                # 检测batch_indices是否合格
                assert len(batch_indices) == self.batch_size, "batch indices have wrong batch size"
                yield from batch_indices
                batch_indices = []

可以看到,这里的_infinite_indices函数,是一个核心,其主导采集数据最重要的过程,根据参数:

        # 对每个身份采集的图像数目(论文中的K)
        self.num_instances = num_instances
        # 通过计算获得每个batch需要采集多少个身份ID(论文中的P)
        self.num_pids_per_batch = batch_size // self.num_instances

合理的进行数据采集。

结语

现在我们依然还存在一个疑问,就是:

    for d in cfg.DATASETS.NAMES:
        # 根据数据集名称,创建对应数据集迭代的类,本人调试为类 fastreid.data.datasets.market1501.Market1501对象
        dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
        train_items.extend(dataset.train)

中的 dataset.train 是构建的。我会再下小节为大家进行讲解。

在这里插入图片描述

  • 6
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

江南才尽,年少无知!

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值