import tensorflow as tf from tensorflow import keras
class CategoricalPrecision(tf.keras.metrics.Metric): def __init__(self, categories=4, threshold=0.0, name='categorical_precision', **kwargs): super(CategoricalPrecision, self).__init__(name=name, **kwargs) self.categories = categories self.threshold = tf.constant([threshold], dtype='float32') self.positives = self.add_weight(name='positives', shape=(categories,), dtype=tf.dtypes.int32, initializer='zeros') self.predications = self.add_weight(name='predications', shape=(categories,), dtype=tf.dtypes.int32, initializer='zeros') def update_state(self, y_true, y_pred, sample_weight=None): preds = tf.cast(tf.math.argmax(y_pred, axis=1), tf.dtypes.int32) probs = tf.reduce_max(y_pred, axis=1) values = tf.cast(tf.wh