def get_batches(x, y, n_batches=10): """ 这是一个生成器函数,按照n_batches的大小将数据划分了小块 """ batch_size = len(x)//n_batches for ii in range(0, n_batches*batch_size, batch_size): # 如果不是最后一个batch,那么这个batch中应该有batch_size个数据 if ii != (n_batches-1)*batch_size: X, Y = x[ii: ii+batch_size], y[ii: ii+batch_size] # 否则的话,那剩余的不够batch_size的数据都凑入到一个batch中 else: X, Y = x[ii:], y[ii:] # 生成器语法,返回X和Y yield X, Y