tensorflow的feed_dict原理,它需要的每次生成一个batch的数据。
1. Dataset类
将数据处理部分写成一个类,init函数中定义了一些参数
class DataSet(object):
def __init__(self,
images,
labels,.....)
self._images = images
self._labels = labels
self._epochs_completed = 0 # 已经经过了多少个epoch
self._index_in_epoch = 0 # 在一个epoch中的index
self._num_examples #是指训练数据的样本总个数
2.next_batch函数
如何保证每一次调用next_batch函数还能记住上一次的位置呢?tensorflow源码是将dataset输入写为一个类,self._index_in_epoch就相当于一个类变量,记住上一次的位置。
以下函数主要分为三部分,
第一个epoch怎么处理,
每个epoch的结尾连接下一个epoch的开头怎么处理,
非第一个epoch&非结尾怎么处理。
这样分开,主要是因为每个epoch的开头,都要shuffle index.即将所有数据顺序都打乱
def next_batch(self, batch_size, fake_data=False, shuffle=True):
start = self._index_in_epoch #self._index_in_epoch 所有的调用,总共用了多少个样本,相当于一个全局变量 #start第一个batch为0,剩下的就和self._index_in_epoch一样,如果超过了一个epoch,在下面还会重新赋值。
# Shuffle for the first epoch 第一个epoch需要shuffle
if self._epochs_completed == 0 and start == 0 and shuffle:
perm0 = numpy.arange(self._num_examples) #生成的一个所有样本长度的np.array
numpy.random.shuffle(perm0)
self._images = self.images[perm0]
self._labels = self.labels[perm0]
# Go to the next epoch
if start + batch_size > self._num_examples: #epoch的结尾和下一个epoch的开头
# Finished epoch
self._epochs_completed += 1
# Get the rest examples in this epoch
rest_num_examples = self._num_examples - start # 最后不够一个batch还剩下几个
images_rest_part = self._images[start:self._num_examples]
labels_rest_part = self._labels[start:self._num_examples]
# Shuffle the data
if shuffle:
perm = numpy.arange(self._num_examples)
numpy.random.shuffle(perm)
self._images = self.images[perm]
self._labels = self.labels[perm]
# Start next epoch
start = 0
self._index_in_epoch = batch_size - rest_num_examples
end = self._index_in_epoch
images_new_part = self._images[start:end]
labels_new_part = self._labels[start:end]
return numpy.concatenate((images_rest_part, images_new_part), axis=0) , numpy.concatenate((labels_rest_part, labels_new_part), axis=0)
else: # 除了第一个epoch,以及每个epoch的开头,剩下中间batch的处理方式
self._index_in_epoch += batch_size # start = index_in_epoch
end = self._index_in_epoch #end很简单,就是 index_in_epoch加上batch_size
return self._images[start:end], self._labels[start:end] #在数据x,y