Training a Keras Model with a yield Generator

1. What is yield?

A function that contains yield produces a generator. By calling next on the generator, we can keep producing one batch of data at a time.

Example:

>>> def test():
...     testData = ['a', 'b', 'c', 'd', 'e', 'f']
...     for x in testData:
...         yield x
...         print("test----" + x)
... 
>>> Generator = test()
>>> x = next(Generator)
>>> x
'a'
>>> x = next(Generator)
test----a
>>> x
'b'
>>> x = next(Generator)
test----b
>>> x
'c'

As you can see, the first call to next runs test up to the yield, returns the first letter 'a', and then suspends. The second call resumes from exactly where execution stopped, so test----a is printed first; the for loop then advances, yield returns 'b', and the cycle repeats until the for loop is exhausted.
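Since a generator is an ordinary iterable, a for loop can also drive it, calling next implicitly until StopIteration. The same test function behaves like this:

>>> for x in test():
...     print(x)
... 
a
test----a
b
test----b
c
test----c
d
test----d
e
test----e
f
test----f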

 

2. The next method

A generator you write is driven by calling next, which runs the function to the next yield and hands back the data. Note that the calling convention differs by version: Python 2 uses Generator.next(), while Python 3 uses next(Generator).
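For example, with the test generator defined above:

Generator = test()
x = next(Generator)     # Python 3 style (the next built-in also exists in Python 2.6+)
# x = Generator.next()  # Python 2 only; the .next() method was removed in Python 3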

 

3. Writing a generator with yield

Because yield lets us keep returning data while preserving the point where the previous next call stopped, we can wrap the body in a while(1) loop and produce batches indefinitely. The resulting generator can either be consumed directly with next in a loop, or passed straight into a Keras model and trained with model.fit_generator(generator=train_generator).

The code skeleton looks like this:

import numpy as np
from imageio import imread                                # any image reader works here
from keras.applications.imagenet_utils import preprocess_input

###
#  imgLabelDict : dict of the form {image file name : label of that image}
#  imgNameSet   : list of all image file names
#  batch_size   : number of images in each batch
#  img_num      : total number of images
###
def Generator(imgLabelDict, imgNameSet, batch_size, img_num):
    i = 0
    labelGenerator = []
    imgGenerator = []
    num = 0
    while 1:
        # Once every image has been visited, shuffle the images and start a new pass.
        if num == img_num or num == 0:
            # shuffle imgNameSet here
            num = 0

        # get the img name
        img_name = imgNameSet[num]

        # label processing
        label = imgLabelDict.get(img_name)
        labelGenerator.append(label)

        # image processing
        img = imread(img_name)
        img = preprocess_input(img)
        imgGenerator.append(img)

        # a full batch has accumulated: yield it and reset the buffers
        if i >= batch_size - 1:
            yield np.array(imgGenerator), np.array(labelGenerator)
            i = -1
            labelGenerator = []
            imgGenerator = []
        i += 1
        num += 1


trainGenerator = Generator(imgLabelDict, imgNameSet, batch_size, img_num)
trainBatchImg, trainBatchLabel = next(trainGenerator)
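To train from this generator directly, hand it to fit_generator. A minimal sketch, assuming model is an already-built and compiled Keras model (the epoch count is just a placeholder):

model.fit_generator(generator=trainGenerator,
                    steps_per_epoch=img_num // batch_size,
                    epochs=10)

steps_per_epoch tells Keras how many next calls make up one epoch, since an endless generator has no natural end.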

 

Known issue:

Because the implementation relies on a while(1) loop, CPU usage can be quite high; for the underlying reason, see this blog post: https://blog.csdn.net/S1amDuncan/article/details/78840031
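If the CPU load is a concern, one common alternative in Keras 2 is keras.utils.Sequence: batches are built on demand in __getitem__ instead of inside an endless loop, and fit_generator accepts a Sequence directly (which also enables safe multi-process loading via its workers/use_multiprocessing arguments). A minimal sketch, reusing the same imgLabelDict/imgNameSet/batch_size assumptions as above:

import numpy as np
from imageio import imread                                # same image reader assumption as above
from keras.applications.imagenet_utils import preprocess_input
from keras.utils import Sequence

class ImgSequence(Sequence):
    def __init__(self, imgLabelDict, imgNameSet, batch_size):
        self.imgLabelDict = imgLabelDict
        self.imgNameSet = imgNameSet
        self.batch_size = batch_size

    def __len__(self):
        # number of batches per epoch
        return len(self.imgNameSet) // self.batch_size

    def __getitem__(self, idx):
        # build batch `idx` on demand; no while(1) loop is needed
        names = self.imgNameSet[idx * self.batch_size:(idx + 1) * self.batch_size]
        imgs = np.array([preprocess_input(imread(name)) for name in names])
        labels = np.array([self.imgLabelDict.get(name) for name in names])
        return imgs, labels

model.fit_generator(ImgSequence(imgLabelDict, imgNameSet, batch_size), epochs=10)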

 

 

Below is the complete code for running prediction on input data with the Keras4bert framework:

```python
import numpy as np
from keras4bert.models import build_transformer_model
from keras4bert.layers import MaskedSoftmax
from keras4bert.tokenizers import Tokenizer
from keras4bert.optimizers import AdamWarmup
from keras4bert.callbacks import ModelCheckpoint

# 1. Define the model parameters
maxlen = 128
num_classes = 2
hidden_size = 768
num_hidden_layers = 12
num_attention_heads = 12
intermediate_size = 3072
hidden_act = 'gelu'
dropout_rate = 0.1
attention_dropout_rate = 0.1
initializer_range = 0.02
learning_rate = 2e-5
weight_decay_rate = 0.01
num_warmup_steps = 10000
num_train_steps = 100000
batch_size = 32

# 2. Load the pre-trained model
model = build_transformer_model(
    config_path='path/to/bert_config.json',
    checkpoint_path='path/to/bert_model.ckpt',
    model='bert',
    num_classes=num_classes,
    application='sequence_classification',
    mask_zero=True,
    with_pool=True,
    return_keras_model=False,
    name='BertModel'
)

# 3. Build the training data generator
tokenizer = Tokenizer('path/to/vocab.txt')

def data_generator(data, batch_size):
    """Data generator"""
    batch_token_ids, batch_segment_ids, batch_labels = [], [], []
    while True:
        for token_ids, label in data:
            batch_token_ids.append(token_ids)
            batch_segment_ids.append([0] * maxlen)
            batch_labels.append(label)
            if len(batch_token_ids) == batch_size:
                batch_token_ids = tokenizer.pad_sequences(batch_token_ids, maxlen=maxlen)
                batch_segment_ids = tokenizer.pad_sequences(batch_segment_ids, maxlen=maxlen)
                batch_labels = np.array(batch_labels)
                yield [batch_token_ids, batch_segment_ids], batch_labels
                batch_token_ids, batch_segment_ids, batch_labels = [], [], []

# 4. Compile the model
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=AdamWarmup(
        decay_steps=num_train_steps,
        warmup_steps=num_warmup_steps,
        lr=learning_rate,
        min_lr=1e-7,
        weight_decay_rate=weight_decay_rate,
        exclude_from_weight_decay=['Norm', 'bias']
    ),
    metrics=['accuracy']
)

# 5. Train
train_data = [(token_ids, label) for token_ids, label in zip(train_token_ids, train_labels)]
valid_data = [(token_ids, label) for token_ids, label in zip(valid_token_ids, valid_labels)]

train_steps = len(train_data) // batch_size
valid_steps = len(valid_data) // batch_size

callbacks = [
    ModelCheckpoint(
        filepath='path/to/best_model.h5',
        save_weights_only=True,
        save_best_only=True,
        monitor='val_accuracy',
        mode='max',
        verbose=1
    )
]

model.fit_generator(
    generator=data_generator(train_data, batch_size),
    steps_per_epoch=train_steps,
    epochs=num_train_steps,
    validation_data=data_generator(valid_data, batch_size),
    validation_steps=valid_steps,
    callbacks=callbacks
)

# 6. Predict on new data
test_token_ids = tokenizer.texts_to_sequences(test_texts)
test_token_ids = tokenizer.pad_sequences(test_token_ids, maxlen=maxlen)
test_preds = model.predict([test_token_ids, np.zeros_like(test_token_ids)])
test_preds = np.argmax(test_preds, axis=1)
```