深度学习中经常用到的一个技巧是使用批训练,这样的好处是可以减少显存的资源占用,对训练的结果也有一定的影响。
下面简单编写一个批量数据生成器:
import random
import numpy
mode = 0
x = np.arange(100)### 假设这个为features
y = np.arange(100) ########## 假设这个为labels
def batch_generator(data,shuffle,batch_size):
count = 0
if shuffle: ### 是否打乱
shuffle_index = random.shuffle(list(range(data[0].shape[0])),data[