import tensorflow as tf
# Compute classification accuracy on a toy validation batch, excluding
# samples whose label is -1 from the accuracy.
# Uses the TensorFlow 1.x graph/session API (tf.Session).

# Per-sample class scores: 4 samples x 2 classes.
cls_prob = tf.constant([[0.8, 0.9], [0.95, 0.9], [0.7, 0.9], [0.6, 0.9]], tf.float32)
# Predicted class index per sample (argmax over the class axis); int64, shape (4,).
pred = tf.argmax(cls_prob, axis=1)
# Ground-truth labels; -1 marks a sample to ignore. Stored as float32, so the
# gathered predictions are cast to float32 below before comparing.
label_int = tf.constant([-1, 0, 0, 1], tf.float32)
# Positions of samples with a valid label (label >= 0). This covers every
# non-ignored sample, both class 0 and class 1, not only positives.
# tf.where returns an int64 index tensor of shape (k, 1).
cond = tf.where(tf.greater_equal(label_int, 0))
# Flatten the indices to shape (k,). Only axis 1 is squeezed so the result
# stays 1-D even when exactly one sample is valid (a bare squeeze would
# collapse (1, 1) to a scalar).
picked = tf.squeeze(cond, axis=1)
label_picked = tf.gather(label_int, picked)
pred_picked = tf.cast(tf.gather(pred, picked), tf.float32)
# Fraction of valid samples whose prediction equals the label.
# NOTE: if no sample has a valid label this is a mean over an empty tensor (0/0).
accuracy_op = tf.reduce_mean(tf.cast(tf.equal(label_picked, pred_picked), tf.float32))

with tf.Session() as sess:
    # The graph above creates no tf.Variable, so no initializer op is needed.
    print(sess.run(tf.greater_equal(label_int, 0)))
    print(sess.run(cond))
    print(sess.run(picked))
    print(sess.run(label_picked))
    print(sess.run(accuracy_op))
TensorFlow 数据使用验证（验证集准确率计算）
最新推荐文章于 2023-07-28 14:18:50 发布