【ReID】【代码注释】采样器 deep-person-reid/samplers.py

源码URL:
https://github.com/michuanhaohao/deep-person-reid/blob/master/samplers.py

采样器读源码注释如下

from __future__ import absolute_import
from collections import defaultdict
import numpy as np

import torch
from torch.utils.data.sampler import Sampler

class RandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.

    Args:
        data_source (Dataset): dataset to sample from.
        num_instances (int): number of instances per identity.
    """

    def __init__(self, data_source, num_instances=4):
        self.data_source = data_source  # 先把传入参数放到类中
        self.num_instances = num_instances
        self.index_dic = defaultdict(list)  # 一种自动生长的字典,不限定长度,存放多少就生成多少个key:value
        for index, (_, pid, _) in enumerate(data_source):  # 取index和pid
            self.index_dic[pid].append(index)  # 放到dictionary中

        self.pids = list(self.index_dic.keys())  # 取keys建list
        self.num_identities = len(self.pids)  # id的数量


    def __iter__(self):  # 如果分类,返回的是0-9999的一个List(0表示第0张图)
        # 所有图片中挑了3004张图片,序号都是安排好的
        indices = torch.randperm(self.num_identities)  # shuffle
        ret = []  # 每个t新建一个list,  每相邻4张,表示一个id,采样93个batch_size
        for i in indices:
            pid = self.pids[i]
            t = self.index_dic[pid]
            replace = False if len(t) >= self.num_instances else True
            t = np.random.choice(t, size=self.num_instances, replace=replace)  # replace=False时shuffle不重复,
                                                                           # 但若取得图少于4会报错,所以设置上面的判断
            ret.extend(t)
#        from IPython import embed
#        embed()
        return ret

    def __len__(self):
        raise self.num_instances * self.num_identities  # 返回长度



#if __name__ == '__main__':
#    from util.data_manager import Market1501
#    dataset = Market1501(root='F:/Market-1501')
#    sampler = RandomIdentitySampler(dataset.train, num_instances=4)
#    a = sampler.__iter__()

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

锥栗

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值