深度学习训练时,用相关的数据进行训练会让训练的数据局限于满足一小撮数据,缺少实用性,因此数据间相互独立才能训练出较好的模型。而更多的现实情况是数据之间或多或少存在着相关性,所以深度学习常常选择海量数据来弥补数据间相关产生的训练模型相关。在大批量数据学习中,使用小批量梯度学习是一种比较好的方式,每次选取一小部分数据进行梯参数更新,既能沿着较好的方向更新,又能兼顾训练的效率。另外,还有一种方法可以减弱数据相关性的影响,就是从数据集中随机抽取数据。
pytorch中引入了对数据进行切分分组的机制,下面通过代码说明pytorch如何将数据分成多个batch。
import torch
import torch.utils.data as Data
torch.manual_seed(1) # reproducible
BATCH_SIZE = 8 # 每个batch的大小,取5或者8
# 生成测试数据
x = torch.linspace(0, 9, 10<