【python】小批次数据生成器

机器学习中的小批次数据生成器,上代码以备以后使用:

#inputs, targets就是 X 和 y
def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    assert len(inputs) == len(targets)
    if shuffle:
        indices = np.arange(len(inputs))
        np.random.shuffle(indices)
    for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
        if shuffle:
            excerpt = indices[start_idx:start_idx + batchsize]
        else:
            excerpt = slice(start_idx, start_idx + batchsize)
        yield inputs[excerpt], targets[excerpt]

用法:

>>>X = np.random.randn(6,4)
>>>y = np.random.randint(0,2,6)
>>>gen = iterate_minibatches(X, y, 10, shuffle=False)
>>>next(gen)
Out[113]: 
(array([[-1.29894375,  1.64519902,  0.449616  ,  0.2820526 ],
        [ 0.49223679,  0.4552031 ,  0.40525815,  0.22591931]]), array([0, 0]))

>>>next(gen)
Out[114]: 
(array([[ 1.46478028,  0.26404314,  0.79638976,  0.223665  ],
        [ 0.08203239,  1.52025888,  1.2913706 ,  0.38070632]]), array([0, 0]))

>>>next(gen)
Out[115]: 
(array([[-0.53161078, -0.3045918 ,  0.63480717, -0.12489774],
        [-0.41806273,  0.09459128, -0.4725088 , -0.83306486]]), array([1, 0]))

>>>next(gen)
Traceback (most recent call last):

  File "<ipython-input-116-b2c61ce5e131>", line 1, in <module>
    gen.next()

StopIteration

意思就是把所有数据都迭代完后,迭代器就作废了,如果想要实现循环迭代,得再加一句while True。
循环迭代版:

def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    assert len(inputs) == len(targets)
    while True:
        if shuffle:
            indices = np.arange(len(inputs))
            np.random.shuffle(indices)
        for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
            if shuffle:
                excerpt = indices[start_idx:start_idx + batchsize]
            else:
                excerpt = slice(start_idx, start_idx + batchsize)
            yield inputs[excerpt], targets[excerpt]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

学渣渣渣渣渣

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值