【Tensorflow】利用tensor计算多分类问题的Accuracy、Micro-F1、Macro-F1

具体代码分析参考文章
logits.shape = [nodes, node_class]
labels.shape = [nodes, node_class]
mask.shape = [1, nodes]

    def masked_accuracy_final(logits, labels, mask):
        """Accuracy with masking."""
        correct_prediction = tf.equal(
            tf.argmax(logits, 1), tf.argmax(labels, 1))
        accuracy_all = tf.cast(correct_prediction, tf.float32)
        mask_acc = tf.cast(mask, dtype=tf.float32)
        mask_acc /= tf.reduce_mean(mask_acc)
        accuracy_all *= mask_acc

        print("micro")
        mask_f1 = tf.cast(mask, dtype=tf.int64)
        f1_micro_labels = tf.add(tf.argmax(labels, 1), 1)*mask_f1  # mask节点以外的节点标签设为0
        f1_micro_predictions = tf.add(tf.argmax(logits, 1), 1)*mask_f1

        num_classes = logits.shape[1]+1  # 节点类型总数+1
        cm, op = _streaming_confusion_matrix(f1_micro_labels, f1_micro_predictions, num_classes)
        # with tf.Session() as sess:
        #     sess.run(tf.local_variables_initializer())  # 初始化所有的local variables
        #     print(sess.run(cm))
        #     print(sess.run(op))
        cm = op  # 令cm是更新后的混淆矩阵
        # cm = tf.contrib.metrics.confusion_matrix(f1_micro_predictions, f1_micro_labels, num_classes=None,
        #                                                        dtype=tf.int32, name=None, weights=None)
        pos_indices = [i for i in range(1, num_classes)]  # 排除类型为0的节点 再利用混淆矩阵计算f1-micro f-macro

        num_classes = cm.shape[0]
        neg_indices = [i for i in range(num_classes) if i not in pos_indices]
        cm_mask = np.ones([num_classes, num_classes])
        cm_mask[neg_indices, neg_indices] = 0  # 将负样本预测正确的位置清零零
        diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask))  # 正样本预测正确的数量

        cm_mask = np.ones([num_classes, num_classes])
        cm_mask[:, neg_indices] = 0  # 将负样本对应的列清零
        tot_pred = tf.reduce_sum(cm * cm_mask)  # 所有被预测为正的样本数量

        cm_mask = np.ones([num_classes, num_classes])
        cm_mask[neg_indices, :] = 0  # 将负样本对应的行清零
        tot_gold = tf.reduce_sum(cm * cm_mask)  # 所有正样本的数量

        numerator = diag_sum
        denominator = tot_pred
        numerator, denominator = tf.cast(numerator, dtype=tf.float64), tf.cast(denominator, dtype=tf.float64)
        zeros = tf.zeros_like(numerator, dtype=numerator.dtype)  # 创建全0Tensor
        denominator_is_zero = tf.equal(denominator, zeros)  # 判断denominator是否为零
        pr = tf.where(denominator_is_zero, zeros, numerator / denominator)  # 如果分母为0,则返回零

        numerator = diag_sum
        denominator = tot_gold
        numerator, denominator = tf.cast(numerator, dtype=tf.float64), tf.cast(denominator, dtype=tf.float64)
        zeros = tf.zeros_like(numerator, dtype=numerator.dtype)  # 创建全0Tensor
        denominator_is_zero = tf.equal(denominator, zeros)  # 判断denominator是否为零
        re = tf.where(denominator_is_zero, zeros, numerator / denominator)  # 如果分母为0,则返回零

        numerator = 2. * pr * re
        denominator = pr + re
        numerator, denominator = tf.cast(numerator, dtype=tf.float64), tf.cast(denominator, dtype=tf.float64)
        zeros = tf.zeros_like(numerator, dtype=numerator.dtype)  # 创建全0Tensor
        denominator_is_zero = tf.equal(denominator, zeros)  # 判断denominator是否为零
        f1_micro_final = tf.where(denominator_is_zero, zeros, numerator / denominator)  # 如果分母为0,则返回零

        print("macro")
        precisions, recalls, f1s, n_golds = [], [], [], []
        # 计算每个正例的precison,recall和f1
        for idx in pos_indices:
            num_classes = cm.shape[0]
            neg_indices = [i for i in range(num_classes) if i not in [idx]]
            cm_mask = np.ones([num_classes, num_classes])
            cm_mask[neg_indices, neg_indices] = 0  # 将负样本预测正确的位置清零零
            diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask))  # 正样本预测正确的数量

            cm_mask = np.ones([num_classes, num_classes])
            cm_mask[:, neg_indices] = 0  # 将负样本对应的列清零
            tot_pred = tf.reduce_sum(cm * cm_mask)  # 所有被预测为正的样本数量

            cm_mask = np.ones([num_classes, num_classes])
            cm_mask[neg_indices, :] = 0  # 将负样本对应的行清零
            tot_gold = tf.reduce_sum(cm * cm_mask)  # 所有正样本的数量

            numerator = diag_sum
            denominator = tot_pred
            numerator, denominator = tf.cast(numerator, dtype=tf.float64), tf.cast(denominator, dtype=tf.float64)
            zeros = tf.zeros_like(numerator, dtype=numerator.dtype)  # 创建全0Tensor
            denominator_is_zero = tf.equal(denominator, zeros)  # 判断denominator是否为零
            pr = tf.where(denominator_is_zero, zeros, numerator / denominator)  # 如果分母为0,则返回零

            numerator = diag_sum
            denominator = tot_gold
            numerator, denominator = tf.cast(numerator, dtype=tf.float64), tf.cast(denominator, dtype=tf.float64)
            zeros = tf.zeros_like(numerator, dtype=numerator.dtype)  # 创建全0Tensor
            denominator_is_zero = tf.equal(denominator, zeros)  # 判断denominator是否为零
            re = tf.where(denominator_is_zero, zeros, numerator / denominator)  # 如果分母为0,则返回零

            numerator = 2. * pr * re
            denominator = pr + re
            numerator, denominator = tf.cast(numerator, dtype=tf.float64), tf.cast(denominator, dtype=tf.float64)
            zeros = tf.zeros_like(numerator, dtype=numerator.dtype)  # 创建全0Tensor
            denominator_is_zero = tf.equal(denominator, zeros)  # 判断denominator是否为零
            f1 = tf.where(denominator_is_zero, zeros, numerator / denominator)  # 如果分母为0,则返回零

            precisions.append(pr)
            recalls.append(re)
            f1s.append(f1)
            cm_mask = np.zeros([num_classes, num_classes])
            cm_mask[idx, :] = 1
            n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask)))
        # pr = tf.reduce_mean(precisions)
        # re = tf.reduce_mean(recalls)
        f1_macro_final = tf.reduce_mean(f1s)
        return tf.reduce_mean(accuracy_all), f1_micro_final, f1_macro_final

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值