class sampler(Sampler):
def __init__(self, train_size, batch_size):
self.num_data = train_size # 得到数据的数量
self.num_per_batch = int(train_size / batch_size) # 得到数据批次的标号
self.batch_size = batch_size #
self.range = torch.arange(0,batch_size).view(1, batch_size).long() # 得到0-batch_size的标号
self.leftover_flag = False
if train_size % batch_size: # 如果数据不能被整除
self.leftover = torch.arange(self.num_per_batch*batch_size, train_size).long() # 不能被整除部分的数据
self.leftover_flag = True # 保留最后不能整除的批次数据
def __iter__(self):
# 将数据批次的标号打乱,并和数据批次的长度相乘,得到随机排序后的照片起始标号
rand_num = torch.randperm(self.num_per_batch).view(-1,1) * self.batch_size
# 将起始位置扩充,得到具体每一张图片对应的标号
self.rand_num = rand_num.expand(self.num_per_batch, self.batch_size) + self.range
# 将所有标号打平铺开
self.rand_num_view = self.rand_num.view(-1)
if self.leftover_flag: # 如果保留最后的批次数据
# 将不能整除的数据与之前的数据合并
self.rand_num_view = torch.cat((self.rand_num_view, self.leftover),0)
# 对所有的标号进行迭代
return iter(self.rand_num_view)
def __len__(self):
return self.num_data
faster-rcnn中的sample函数注释解析
最新推荐文章于 2022-03-15 08:23:14 发布