多输入模型多适用于问答模型或者对于时间序列模型来说有部分特征是针对样本个体而固定的,不随时间变换而发生改变的情况下。
对于模型的输入数据格式来说,有很多种方式,普通的全部数据导入,或者写成生成器等,可以逐批读取数据然后训练模型,但是当你使用tensorflow内置分布式训练,也就是多机多卡模卡MultiWorkerMirroredStrategy的时候,就必须使用Dataset格式。
因为Dataset会自动根据batch_size分发数据进行迭代训练。
如果对MultiWorkerMirroredStrategy以及MirroredStrateg两种训练模型感兴趣的可以去看看我的另外一篇文章。https://blog.csdn.net/qq_35869630/article/details/106313745
一、多输入模型
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Dense, concatenate,Bidirectional
from tensorflow.keras import Input,Model
def build_and_compile_model():
inputA = Input(shape=(7, 22))
inputB = Input(shape=(3,))
x = LSTM(128, return_sequences=False, activation="relu")(inputA)
x = Dense(128, activation="relu")(x)
x = Dense(64, activation="relu")(x)
x = Model(inputs=inputA, outputs=x)
y = Dense(32, activation="relu")(inputB)
y = Dense(16, activation="relu")(y)
y = Model(inputs=inputB, outputs=y)
combined = concatenate([x.output, y.output], axis=-1)
z = Dense(32, activation="relu")(combined)
z = Dense(1, activation="sigmoid")(z)
model = Model(inputs=[x.input, y.input], outputs=z)
model.compile(loss='binary_crossentropy', optimizer='rmsprop',