keras data generation, python生成器

Implement fit_generator( ) in Keras原文链接

Here is an example of fit_generator():

model.fit_generator(generator(), samples_per_epoch=50, nb_epoch=10)

Breaking it down:

generator() generates batches of samples indefinitely

sample_per_epoch number of samples you want to train in each epoch

nb_epoch number of epochs

As you can manually definesample_per_epoch andnb_epoch , you have to provide codes forgenerator . Here is an example:

Assume features is an array of data with shape (100,64,64,3) and labels is an array of data with shape (100,1). We use data from features and labels to train our model.

def generator(features, labels, batch_size):

 # Create empty arrays to contain batch of features and labels#

 batch_features = np.zeros((batch_size, 64, 64, 3))
 batch_labels = np.zeros((batch_size,1))

 while True:
   for i in range(batch_size):
     # choose random index in features
     index= random.choice(len(features),1)
     batch_features[i] = some_processing(features[index])
     batch_labels[i] = labels[index]
   yield batch_features, batch_labels


在python中,当你定义一个函数,使用了yield关键字时,这个函数就是一个生成器" (也就是说,只要有yield这个词出现,你在用def定义函数的时候,系统默认这就不是一个函数啦,而是一个生成器)。如果需要生成器返回(下一个)值,需要调用.next()函数。其实当系统判断def是生成器时,就会自动支持.next()函数,例如:

    def fib(max):  
        a, b = 1, 1  
        while a < max:  
            yield a  
            a, b = b, a+b  
      
    for n in fib(15):  
        print n  
      
    m = fib(13)  
    print m  
    print m.next()  
    print m.next()  
    print m.next()  



1. 每个生成器只能使用一次。比如上个例子中的m生成器,一旦打印完m的6个值,就没有办法再打印m的值了,因为已经吐完了。生成器每次运行之后都会在运行到yield的位置时候,保存暂时的状态,跳出生成器函数,在下次执行生成器函数的时候会从上次截断的位置继续开始执行循环。

2. yield一般都在def生成器定义中搭配一些循环语句使用,比如for或者while,以防止运行到生成器末尾跳出生成器函数,就不能再yield了。有时候,为了保证生成器函数永远也不会执行到函数末尾,会用while True: 语句,这样就会保证只要使用next(),这个生成器就会生成一个值,是处理无穷序列的常见方法。

拿上面那个为例, 每次继续开始执行上次没处理完成的位置,但后面的每次循环都只在while True这个循环体内部运行,之前的非循环体batch_feature...  batch_label ...并没有执行,因为它们只在第一次进入生成其函数的时候才有效地运行过一次。


With the generator above, if we definebatch_size = 10 , that means it will randomly taking out 10 samples fromfeatures and labels to feed into each epoch until an epoch hits 50 sample limit. Then fit_generator() destroys the used data and move on repeating the same process in new epoch.

One great advantage aboutfit_generator() besides saving memory is user can integrate random augmentation inside the generator, so it will always provide model with new data to train on the fly.


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值