tensorflow2.0 语义分割 评价指标代码(tensor版)

np编写的语义分割评价指标不能通过GPU 加速,故改编了一份tensor的代码,在tf2.0测试通过。第一次写博客,将就着看吧

def get_hist_t(predictions, labels, num_class):
# labels:[b.h,w] tensor
# predictions:[b.h,w,c] tensor
num_class = predictions.shape[3]
batch_size = predictions.shape[0]
hist = tf.zeros((num_class, num_class), dtype=‘int32’)

for i in range(batch_size):
    hist += fast_hist_t(tf.squeeze(tf.reshape(labels[i], (-1, 1)), 1),
                        tf.squeeze(tf.reshape(tf.argmax(predictions[i], axis=2), (-1, 1)), 1), num_class)
return hist

def per_class_acc_t(predictions, label_tensor):
labels = label_tensor
size = predictions.shape[0]
num_class = predictions.shape[3]
hist = tf.zeros((num_class, num_class), dtype=‘int32’)

#     num_class = predictions.shape[3]
#     batch_size = predictions.shape[0]
#     hist =tf.zeros((num_class, num_class),dtype='int32')
for i in range(size):
    hist += fast_hist_t(tf.squeeze(tf.reshape(labels[i], (-1, 1)), 1),
                        tf.squeeze(tf.reshape(tf.argmax(predictions[i], axis=2), (-1, 1)), 1), num_class)
#           hist += fast_hist(labels[i].flatten(), predictions[i].argmax(2).flatten(), num_class)

acc_total = tf.reduce_sum(tf.linalg.tensor_diag_part(hist)) / tf.reduce_sum(hist)
print('accuracy = %f' % np.nanmean(acc_total))
iu = tf.linalg.tensor_diag_part(hist) / (
            tf.reduce_sum(hist, 1) + tf.reduce_sum(hist, 0) - tf.linalg.tensor_diag_part(hist))
print('mean IU  = %f' % np.nanmean(iu))
for ii in range(num_class):
    if float(tf.reduce_sum(hist, 1)[ii]) == 0:
        acc = 0.0
    else:
        acc = tf.cast(tf.linalg.tensor_diag_part(hist)[ii], dtype='float32') / tf.cast(tf.reduce_sum(hist, 1)[ii],
                                                                                       dtype='float32')
    print("    class # %d accuracy = %f " % (ii, acc))

def fast_hist_t(a, b, n):#tensor版
k = (a >= 0) & (a < n)
d=tf.math.bincount(n * tf.cast(a[k],dtype=‘int32’) + tf.cast(b[k],dtype=‘int32’), minlength=n2)
# return np.bincount(n * a[k].astype(int) + b[k], minlength=n
2).reshape(n, n)
return tf.reshape(d,(n, n))

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值