# Sampler.py: sampler that builds each batch from N = P*K samples (P identities x K instances per identity).
from __future__ import absolute_import
from collections import defaultdict
import numpy as np
from .data_manager import Market1501
import torch
from torch.utils.data.sampler import Sampler
class RandomIdentitySampler(Sampler):
    """Randomly sample identities, then K instances per identity.

    Each pass over the sampler visits every identity once, in random order,
    and emits ``num_instances`` dataset indices for it. Consecutive groups of
    ``num_instances`` indices therefore share the same identity. When the
    batch size is ``P * num_instances``, each batch holds P identities with
    K instances each (batch size N = P * K).

    Args:
        data_source (Dataset): dataset to sample from. Each item is unpacked
            as a 3-tuple whose second element is the identity label (pid).
        num_instances (int): number of instances (K) drawn per identity.
    """

    def __init__(self, data_source, num_instances):
        self.data_source = data_source
        self.num_instances = num_instances
        # Group dataset indices by identity label.
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())
        self.num_identities = len(self.pids)

    def __iter__(self):
        """Return an iterator over one epoch of dataset indices."""
        ret = []
        # Visit identities in a fresh random order on every epoch.
        for i in torch.randperm(self.num_identities).tolist():
            pid = self.pids[i]
            t = self.index_dic[pid]
            # Identities with fewer than K instances are oversampled with
            # replacement; otherwise K distinct instances are drawn.
            replace = len(t) < self.num_instances
            t = np.random.choice(t, size=self.num_instances, replace=replace)
            # .tolist() yields plain Python ints rather than numpy scalars.
            ret.extend(t.tolist())
        # __iter__ must return an iterator, not a list.
        return iter(ret)

    def __len__(self):
        """Number of indices produced per epoch: identities x instances."""
        return self.num_identities * self.num_instances
if __name__ == '__main__':
    # Smoke test: build the sampler over Market1501's training split and draw
    # one epoch of indices. Because this file uses a relative import, run it
    # as a module (python -m <package>.Sampler), not as a plain script.
    dataset = Market1501(root='../../data')
    sampler = RandomIdentitySampler(dataset.train, num_instances=4)
    a = list(iter(sampler))
    print(len(sampler), a[:16])