# 源码URL (source URL):
# https://github.com/michuanhaohao/deep-person-reid/blob/master/samplers.py
# 采样器读源码注释如下 (annotated reading notes on the sampler source code follow)
from __future__ import absolute_import
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
class RandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.

    One pass over the sampler visits every identity exactly once, in a random
    order, and emits ``num_instances`` consecutive dataset indices for it.
    The total length is therefore ``num_instances * num_identities``.
    Identities with fewer than ``num_instances`` samples are drawn with
    replacement; the others are drawn without replacement.

    Args:
        data_source (Dataset): dataset to sample from. Each item must unpack
            into a 3-tuple whose second element is the identity label
            (presumably ``(img_path, pid, camid)`` -- confirm against the
            dataset in use).
        num_instances (int): number of instances per identity.
    """

    def __init__(self, data_source, num_instances=4):
        self.data_source = data_source
        self.num_instances = num_instances
        # Group dataset indices by identity: pid -> [index, ...].
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())
        self.num_identities = len(self.pids)

    def __iter__(self):
        # Unlike a plain sequential/random sampler, which yields each dataset
        # index once, this yields num_instances indices per identity. Each
        # identity's indices are adjacent, so a batch of size N*K holds
        # N identities with K images each.
        indices = torch.randperm(self.num_identities).tolist()  # shuffle identities
        ret = []
        for i in indices:
            pid = self.pids[i]
            t = self.index_dic[pid]
            # Sampling without replacement needs at least num_instances
            # candidates; otherwise np.random.choice raises ValueError. Small
            # identities therefore fall back to sampling with replacement.
            replace = len(t) < self.num_instances
            t = np.random.choice(t, size=self.num_instances, replace=replace)
            ret.extend(t.tolist())  # plain Python ints, not np.int64
        # __iter__ must return an iterator. Returning the bare list makes
        # iter(sampler), and hence any DataLoader, raise TypeError.
        return iter(ret)

    def __len__(self):
        # Number of indices produced by one pass of __iter__.
        return self.num_instances * self.num_identities
#if __name__ == '__main__':
# from util.data_manager import Market1501
# dataset = Market1501(root='F:/Market-1501')
# sampler = RandomIdentitySampler(dataset.train, num_instances=4)
# a = sampler.__iter__()