我想对具有多个输入的Keras模型执行交叉验证。所以,我试着KerasClassifier。对于只有一个输入的正常序列模型,这种方法可以很好地工作。然而,当使用函数api并扩展到两个输入时,sklearn的cross_val_predict似乎并没有像预期的那样工作。你知道吗def create_model():
input_text = Input(shape=(1,), dtype=tf.string)
embedding = Lambda(UniversalEmbedding, output_shape=(512, ))(input_text)
dense = Dense(256, activation='relu')(embedding)
input_title = Input(shape=(1,), dtype=tf.string)
embedding_title = Lambda(UniversalEmbedding, output_shape=(512, ))(input_title)
dense_title = Dense(256, activation='relu')(embedding_title)
out = Concatenate()([dense, dense_title])
pred = Dense(2, activation='softmax')(out)
model = Model(inputs=[input_text, input_title], outputs=pred)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
失败的交叉验证代码keras_classifier = KerasClassifier(build_fn=create_model, epochs=10, batch_size=10, verbose=1)
cv = StratifiedKFold(n_splits=10, random_state=0)
results = cross_val_predict(keras_classifier, [X1, X2], y, cv=cv, method = "predict_proba")
后来,我发现KerasClassifier只支持序列模型:https://keras.io/scikit-learn-api/。换句话说,它不支持具有多个输入的函数api。你知道吗
因此,我想知道是否有其他方法可以对keras中使用函数api的模型执行交叉验证。{}具体地说,在这个测试中,每个片段的交叉预测的概率是多少。你知道吗
如果需要,我很乐意提供更多细节。你知道吗
编辑:我现在的问题是如何将多个输入输入到StratifiedKFold.split()。我已经把????????????放在代码里了。只是在想是否有可能把它命名为[input1, input2, input3, input4, input5]
假设我有5个输入,如input1、input2、input3、input4、input5,我如何在StratifiedKFold.split()中使用这些输入k_fold = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
for train_index, test_index in k_fold.split(????????????, labels):
print("iteration", i, ":")
print("train indices:", train_index)
#input1
print("train data:", input1[train_index])
#input2
print("train data:", input2[train_index])
#input3
print("train data:", input3[train_index])
#input4
print("train data:", input1[train_index])
#input5
print("train data:", input1[train_index])
print("test indices:", test_index)
print("test data:", X[test_index])