python交叉验证函数_如何在python的keras函数api中执行交叉验证

在Keras中,使用函数API构建了一个具有两个输入的深度学习模型。遇到问题在于,尝试使用sklearn的StratifiedKFold进行交叉验证时,KerasClassifier不支持多输入模型。寻求解决方案,以便能正确地对多个输入执行交叉验证。
摘要由CSDN通过智能技术生成

我想对具有多个输入的Keras模型执行交叉验证。所以,我试着KerasClassifier。对于只有一个输入的正常序列模型,这种方法可以很好地工作。然而,当使用函数api并扩展到两个输入时,sklearn的cross_val_predict似乎并没有像预期的那样工作。你知道吗def create_model():

input_text = Input(shape=(1,), dtype=tf.string)

embedding = Lambda(UniversalEmbedding, output_shape=(512, ))(input_text)

dense = Dense(256, activation='relu')(embedding)

input_title = Input(shape=(1,), dtype=tf.string)

embedding_title = Lambda(UniversalEmbedding, output_shape=(512, ))(input_title)

dense_title = Dense(256, activation='relu')(embedding_title)

out = Concatenate()([dense, dense_title])

pred = Dense(2, activation='softmax')(out)

model = Model(inputs=[input_text, input_title], outputs=pred)

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

return model

失败的交叉验证代码keras_classifier = KerasClassifier(build_fn=create_model, epochs=10, batch_size=10, verbose=1)

cv = StratifiedKFold(n_splits=10, random_state=0)

results = cross_val_predict(keras_classifier, [X1, X2], y, cv=cv, method = "predict_proba")

后来,我发现KerasClassifier只支持序列模型:https://keras.io/scikit-learn-api/。换句话说,它不支持具有多个输入的函数api。你知道吗

因此,我想知道是否有其他方法可以对keras中使用函数api的模型执行交叉验证。{}具体地说,在这个测试中,每个片段的交叉预测的概率是多少。你知道吗

如果需要,我很乐意提供更多细节。你知道吗

编辑:我现在的问题是如何将多个输入输入到StratifiedKFold.split()。我已经把????????????放在代码里了。只是在想是否有可能把它命名为[input1, input2, input3, input4, input5]

假设我有5个输入,如input1、input2、input3、input4、input5,我如何在StratifiedKFold.split()中使用这些输入k_fold = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

for train_index, test_index in k_fold.split(????????????, labels):

print("iteration", i, ":")

print("train indices:", train_index)

#input1

print("train data:", input1[train_index])

#input2

print("train data:", input2[train_index])

#input3

print("train data:", input3[train_index])

#input4

print("train data:", input1[train_index])

#input5

print("train data:", input1[train_index])

print("test indices:", test_index)

print("test data:", X[test_index])

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值