"""
@File : samplers.py
@Time : 2021-05-09 22:35
@Author : XD
@Email : gudianpai@qq.com
@Software: PyCharm
"""
from __future__ import absolute_import
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
class RandomIdentitySampler(Sampler):
    """Visit every identity in random order and draw K instances for each.

    Each pass over the sampler yields ``num_instances`` dataset indices for
    every identity found in ``data_source``, one identity after another, so
    a batch made of N*K consecutive indices holds N identities with K
    instances each.

    Code imported from
    https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.

    Args:
        data_source (Dataset): dataset to sample from. Iterating it must
            yield 3-tuples whose second element is the identity label (pid).
        num_instances (int): number of instances per identity.
    """

    def __init__(self, data_source, num_instances=4):
        self.data_source = data_source
        # Attribute names (including their historical spellings) are kept
        # unchanged so that external code reading them keeps working.
        self.num_instance = num_instances
        # Maps each pid to the positions in data_source carrying that pid.
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic)
        self.num_indentities = len(self.pids)

    def __iter__(self):
        """Return an iterator over the sampled dataset indices (plain ints)."""
        ret = []
        # .tolist() turns the permutation into Python ints, so the list of
        # pids is not indexed with 0-d tensors.
        for i in torch.randperm(self.num_indentities).tolist():
            candidates = self.index_dic[self.pids[i]]
            # Sample with replacement only when the identity has fewer
            # entries than the number of instances requested.
            replace = len(candidates) < self.num_instance
            chosen = np.random.choice(
                candidates, size=self.num_instance, replace=replace
            )
            # Emit built-in ints rather than numpy scalars.
            ret.extend(chosen.tolist())
        return iter(ret)

    def __len__(self):
        """Return the number of indices produced per pass (K * identities)."""
        return self.num_instance * self.num_indentities
if __name__ == '__main__':
    # Manual smoke test: load the Market1501 training split and print the
    # first three indices produced by the sampler.
    import data_manager

    # Raw string: '\d' is not a valid escape sequence, so the plain literal
    # triggers an invalid-escape warning on recent Python versions. The
    # resulting path value is unchanged.
    dataset = data_manager.init_img_dataset(root=r'G:\data', name='market1501')
    sampler = RandomIdentitySampler(dataset.train, num_instances=4)
    a = iter(sampler)
    print(next(a))
    print(next(a))
    print(next(a))
# Example output of the smoke test above:
#
# => Market1501 loaded
# Dataset statistics:
#   ------------------------------
#   subset   | # ids | # images
#   ------------------------------
#   train    |   751 |    12936
#   query    |   750 |     3368
#   gallery  |   751 |    15913
#   ------------------------------
#   total    |  1501 |    32217
#   ------------------------------
# 8560
# 8564
# 8565
#
# Process finished with exit code 0