Keras中fit_generator 的多个分支输入时,需注意generator的格式 以及 输入序列的顺序

需要注意迭代器 yeild返回不能是[x1,x2],y 这样,而是要完整的字典格式的: yield ({'input_1': x1, 'input_2': x2}, {'output': y})  。

这也不算坑 追进去 fit_generator也能看到示例

def generate_batch(x_train,y_train,batch_size,x_train2,randomFlag=True):
    ylen = len(y_train)
    loopcount = ylen // batch_size
    i=-1
    while True:
        if randomFlag:
            i = random.randint(0,loopcount-1)
        else:
            i=i+1
            i=i%loopcount

        yield ({'lstmInput': x_train[i*batch_size:(i+1)*batch_size], 
                'bgInput': x_train2[i*batch_size:(i+1)*batch_size]}, 
            {'prediction': y_train[i*batch_size:(i+1)*batch_size]})  

ps: 因为要是tuple yield后的括号不能省   
 

需注意的坑1是,validation data中如果用【】组成数组进行输入,是要按顺序的,按编译model前的设置model = Model(inputs=[simInput,lstmInput,bgInput], outputs=predictions),中数组的顺序来编译

需注意的坑2是,多输入input时,以后都用 inputs1=Input(batch_shape=(batchSize,TPeriod,dimIn,),name='input1LSTM')指定batchSize,不然跟stateful lstm结合时,会提示不匹配。

history=model.fit_generator(generate_batch(trainX,trainY,batchSize,trainX2),
            steps_per_epoch=len(trainX)//batchSize,
            validation_data=([testX,testX2],testY),
            epochs=epochs,
           callbacks=[tensorboard,checkpoint],initial_epoch=0,verbose=1)  # Fit the LSTM network/拟合LSTM网络
 

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值