faster-rcnn中的sample函数注释解析

class sampler(Sampler):
  def __init__(self, train_size, batch_size):
    self.num_data = train_size # 得到数据的数量
    self.num_per_batch = int(train_size / batch_size) # 得到数据批次的标号
    self.batch_size = batch_size #
    self.range = torch.arange(0,batch_size).view(1, batch_size).long() # 得到0-batch_size的标号
    self.leftover_flag = False
    if train_size % batch_size: # 如果数据不能被整除
      self.leftover = torch.arange(self.num_per_batch*batch_size, train_size).long() # 不能被整除部分的数据
      self.leftover_flag = True # 保留最后不能整除的批次数据

  def __iter__(self):
    # 将数据批次的标号打乱,并和数据批次的长度相乘,得到随机排序后的照片起始标号
    rand_num = torch.randperm(self.num_per_batch).view(-1,1) * self.batch_size
    # 将起始位置扩充,得到具体每一张图片对应的标号
    self.rand_num = rand_num.expand(self.num_per_batch, self.batch_size) + self.range
    # 将所有标号打平铺开
    self.rand_num_view = self.rand_num.view(-1)

    if self.leftover_flag: # 如果保留最后的批次数据
      # 将不能整除的数据与之前的数据合并
      self.rand_num_view = torch.cat((self.rand_num_view, self.leftover),0)
    # 对所有的标号进行迭代
    return iter(self.rand_num_view)

  def __len__(self):
    return self.num_data
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值