【科研小小白的神经网络Day2】什么是batch生成器,为什么又要存在一个生成器,为什么需要分批处理数据?

背景

深度学习的训练数据往往很多,如果一次性训练所有的数据,不但会导致时间过长,而且训练次数不够,参数也不能得到很好的迭代。为此,将训练数据分成小的batch,一次batch迭代就可以完成一次参数更新,大大提高了训练速度。
Pytorch中有现成的batch生成器,但是为了底层原理的理解,最好自己能够写出这样的代码,就先从能看懂现成代码开始吧。

batch生成器函数

def data_iter(batch_size,features,labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0,num_examples,batch_size):
        j = torch.LongTensor(indices[i:min(i+batch_size,num_examples)])
        yield features.index_select(0,j),labels.index_select(0,j)

这个函数名为batch_iter,参数有三:batch_size(batch的大小)、features(训练数据的特征,可以视为自变量)和labels(训练数据的标签,可以视为因变量)。
首先,num_examples获得features变量的长度,这个值就是训练数据的个数;
接着,利用range函数生成从0到training number-1(num_examples-1)的range,利用list函数将其转为列表,这样,indices就是一个包含从0到num_examples-1的list了;
然后利用random包的shuffle函数对indices的数进行洗牌(实际上就是打乱,然后随机排列);
接下来,range(0, num_examples, batch_size)是从0到num_examples-1,每隔batch_size步长产生一个数,即这里的i分别为0,10,20,…,990;
接着,indices[i:min(i+batch_size,num_examples)]是一个索引操作,表示取出indices这个list里从i到下一个变量值-1的子list,其中min(i+batch_size,num_examples)的作用是防止索引超出最大范围;这样,j就得到了一个值类型为long的tensor,值是indices里索引出来的子list;
最后,yield函数是一个生成器,在这里可以简单的看作是return;features.index_select(0,j)中,0表示的是dim,这个操作从features中索引出了所有行数为j的features,组成一个tensor;labels同理。
这样,调用这个函数之后就可以获得一个训练数据的batch了。

实例演示

import torch
import numpy as np
import random

num_inputs = 2
num_examples = 1000
true_w = [2,-3.4]
true_b = 4.2
features = torch.from_numpy(np.random.normal(0,1,(num_examples,num_inputs)))
labels = true_w[0]*features[:,0]+true_w[1]*features[:,1]+true_b
labels += torch.from_numpy(np.random.normal(0,0.01,size = labels.size()))

def data_iter(batch_size,features,labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0,num_examples,batch_size):
        j = torch.LongTensor(indices[i:min(i+batch_size,num_examples)])
        yield features.index_select(0,j),labels.index_select(0,j)

batch_size = 10

for x,y in data_iter(batch_size,features,labels):
    print(x,y) #这里只演示出第一个batch
    break
输出结果:

tensor([[ 1.0289, -0.5676],
        [ 0.4811,  0.0651],
        [-0.7113, -0.7735],
        [ 0.5077,  1.5935],
        [ 0.5343,  0.8802],
        [-1.1659, -1.0234],
        [ 0.1249, -0.2690],
        [-1.9804,  0.9771],
        [-0.5953, -0.0802],
        [ 0.2558, -1.0796]], dtype=torch.float64) tensor([ 8.2047,  4.9386,  5.4247, -0.2031,  2.2704,  5.3475,  5.3715, -3.0973,
         3.2829,  8.3943], dtype=torch.float64)

————————————————
版权声明:本文为CSDN博主「Bellamy_xxx」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_44992157/article/details/127410119

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值