在使用tf.nn.weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight)时,如何传递参数logits?
如果网络的输出是preds,那么不能直接将preds传递给logits,而应该使logits=K.log(preds/(1-preds)),其中K来自于import tensorflow.keras.backend as K
其原因在于该函数的公式表达式为:labels * -log(sigmoid(logits)) * pos_weight + (1 - labels) * -log(1 - sigmoid(logits))
而非交叉熵公式:labels * -log(preds) * pos_weight + (1 - labels) * -log(1 - preds)
根据https://www.jianshu.com/p/31c7fe00d9de的推导,preds = sigmoid(log(preds/(1-preds)))
因此,在该函数的logits参数传递时,需要传入logits=K.log(preds/(1-preds))才能得到交叉熵的公式:labels * -log(preds) * pos_weight + (1 - labels) * -log(1 - preds)