tensorflow(6) mnist.train.next_batch()函数解析

tensorflow的feed_dict原理,它需要的每次生成一个batch的数据。

1. Dataset类

将数据处理部分写成一个类,init函数中定义了一些参数

class DataSet(object):

  def __init__(self,
               images,
               labels,.....)
    self._images = images
    self._labels = labels
    self._epochs_completed = 0 # 已经经过了多少个epoch
    self._index_in_epoch = 0 # 在一个epoch中的index
    self._num_examples #是指训练数据的样本总个数
2.next_batch函数

如何保证每一次调用next_batch函数还能记住上一次的位置呢?tensorflow源码是将dataset输入写为一个类,self._index_in_epoch就相当于一个类变量,记住上一次的位置。
以下函数主要分为三部分,
第一个epoch怎么处理,
每个epoch的结尾连接下一个epoch的开头怎么处理,
非第一个epoch&非结尾怎么处理。
这样分开,主要是因为每个epoch的开头,都要shuffle index.即将所有数据顺序都打乱

def next_batch(self, batch_size, fake_data=False, shuffle=True):
    start = self._index_in_epoch  #self._index_in_epoch  所有的调用,总共用了多少个样本,相当于一个全局变量 #start第一个batch为0,剩下的就和self._index_in_epoch一样,如果超过了一个epoch,在下面还会重新赋值。
    # Shuffle for the first epoch 第一个epoch需要shuffle
    if self._epochs_completed == 0 and start == 0 and shuffle:
      perm0 = numpy.arange(self._num_examples)  #生成的一个所有样本长度的np.array
      numpy.random.shuffle(perm0)
      self._images = self.images[perm0]
      self._labels = self.labels[perm0]
    # Go to the next epoch


    if start + batch_size > self._num_examples: #epoch的结尾和下一个epoch的开头
      # Finished epoch
      self._epochs_completed += 1
      # Get the rest examples in this epoch
      rest_num_examples = self._num_examples - start  # 最后不够一个batch还剩下几个
      images_rest_part = self._images[start:self._num_examples]
      labels_rest_part = self._labels[start:self._num_examples]
      # Shuffle the data
      if shuffle: 
        perm = numpy.arange(self._num_examples)
        numpy.random.shuffle(perm)
        self._images = self.images[perm]
        self._labels = self.labels[perm]
      # Start next epoch
      start = 0
      self._index_in_epoch = batch_size - rest_num_examples
      end = self._index_in_epoch
      images_new_part = self._images[start:end] 
      labels_new_part = self._labels[start:end]
      return numpy.concatenate((images_rest_part, images_new_part), axis=0) , numpy.concatenate((labels_rest_part, labels_new_part), axis=0)    
    else:  # 除了第一个epoch,以及每个epoch的开头,剩下中间batch的处理方式
      self._index_in_epoch += batch_size # start = index_in_epoch
      end = self._index_in_epoch #end很简单,就是 index_in_epoch加上batch_size 
      return self._images[start:end], self._labels[start:end] #在数据x,y
  • 16
    点赞
  • 64
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值