tensorflow中f-measure,Precision / Recall / F1 score 以及Confusion matrix的计算

使用tensorflow计算f-measure和召回率等内容的,需要安装一个sklearnwin7 64 下 你只需要输入 pip3 install sklearn 即可
下边是例子:来自于stackoverflow 感谢这个网站吧:原帖子https://stackoverflow.com/questions/35365007/tensorflow-precision-recall-f1-score-and-confusion-matrix


from sklearn.metrics import confusion_matrix
confusion_matrix(y_true, y_pred)
pred = multilayer_perceptron(x, weights, biases)
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

    with tf.Session() as sess:
    init = tf.initialize_all_variables()
    sess.run(init)
    for epoch in xrange(150):
            for i in xrange(total_batch):
                    train_step.run(feed_dict = {x: train_arrays, y: train_labels})
                    avg_cost += sess.run(cost, feed_dict={x: train_arrays, y: train_labels})/total_batch         
            if epoch % display_step == 0:
                    print "Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost)

    #metrics
    y_p = tf.argmax(pred, 1)
    val_accuracy, y_pred = sess.run([accuracy, y_p], feed_dict={x:test_arrays, y:test_label})

    print("validation accuracy:", val_accuracy)
    y_true = np.argmax(test_label,1)
    print("Precision", sk.metrics.precision_score(y_true, y_pred))
    print( "Recall", sk.metrics.recall_score(y_true, y_pred))
    print( "f1_score", sk.metrics.f1_score(y_true, y_pred))
    print( "confusion_matrix")
    print( sk.metrics.confusion_matrix(y_true, y_pred))
    fpr, tpr, tresholds = sk.metrics.roc_curve(y_true, y_pred)

keras 1.2版本也有这些值

首先在compile加入这些参数

model.compile(loss='categorical_crossentropy',

              optimizer='adam',

              metrics=['accuracy', 'f1score', 'precision', 'recall'])

#然后用plt描绘出来  他们对应的result中存的key分别是:precision,val_precision, recall, val_recall ,acc, val_acc,#loss, val_loss

plt.figure()

fig = plt.gcf()
fig.set_size_inches(18.5, 10.5)
plt.plot(result.epoch,result.history['fmeasure'],label="fmeasure")
plt.plot(result.epoch,result.history['val_fmeasure'],label="val_fmeasure")
plt.scatter(result.epoch,result.history['fmeasure'],marker='*')
plt.scatter(result.epoch,result.history['val_fmeasure'])
plt.title('Fmeasure')
plt.ylabel('fmeasure')
plt.xlabel('epoch \ times')
plt.legend(loc='under right')
plt.show()
  • 2
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值