Tensorflow之batch的解释,采用yield方法解释

  • 再多文字的解释都不如代码来的简洁
  • 看完之后再想一下在神经网络训练的时候引入bacth是多么明智
#本案例其实也是解释了为什么在模型训练时采用batch的方法会更加有效率
#在训练数据十分庞大时,如果只是简单的将数据全部轮训一遍做法很低效,把数据切分会变得有效率
import numpy as np 
def get_batch(x,y,batch):
    n_samples = len(x)
    print("n_samples:",n_samples)
    #n_samples=10,for i in range(3,10,3) 
    #i的值分别是3,6,9,这样实际上只会取到数组[0-9]第10个取不到的
    for i in range(batch,n_samples,batch):
        print("i:",i,"batch:",batch)
        yield x[i-batch:i],y[i-batch:i]
#yield用在函数中,把这函数封装成一个generator(生成器),在调用for i in fun(param)起作用
ma = np.array([[0,1],[1,2],[2,3],[3,4],[4,3],[5,5],[6,2],[7,4],[8,3],[9,5]])
#ma.shape(10,2)
print("ma:",ma[0:3])
#[[0 1][1 2][2 3]]
mb = np.array([0,1,2,3,4,5,6,7,8,9])
#mb.shape(10,)
for j in range(3):  
    for tx,ty in get_batch(ma,mb,3):
        print("tx:",tx,"ty:",ty)
        print("over.")
print("Finished.",tx,ty)
Output:
ma: [[0 1] [1 2] [2 3]]
n_samples: 10
i: 3 batch: 3
tx: [[0 1] [1 2] [2 3]] ty: [0 1 2]
over.
i: 6 batch: 3
tx: [[3 4] [4 3] [5 5]] ty: [3 4 5]
over.
i: 9 batch: 3
tx: [[6 2] [7 4] [8 3]] ty: [6 7 8]
over.  .....循环3次
Finished. [[6 2] [7 4] [8 3]] [6 7 8]

 

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值