创建一个类，并实现 __call__ 方法，使其实例可以像函数一样被调用
class weighted_cross_entropy(object):
def __call__(self, y_pred, y_true):
"""
logits: a Tensor with shape [batch_size, image_width, image_height, channel], score from the unet conv10
label: a Tensor with shape [batch_size, image_width, image_height], ground truth
"""
weight = [0.21008659, 0.26289699, 0.28279202, 0.24422441]
# label = tf.one_hot(tf.cast(y_tru