tensorflow tf.argmax() 用法 例子

转自:https://blog.csdn.net/Jiaach/article/details/78874704

argmax()官方文档如下:

tf.argmax(input, dimension, name=None) 
Returns the index with the largest value across dimensions of a tensor. 
Args: 
input: A Tensor. Must be one of the following types: float32, float64, int64, int32, uint8, int16, int8, complex64, qint8, quint8, qint32. 
dimension: A Tensor of type int32. int32, 0 <= dimension < rank(input). Describes which dimension of the input Tensor to reduce across. For vectors, use dimension = 0. 
name: A name for the operation (optional). 
Returns: 
A Tensor of type int64.

dimension=0 按列找 
dimension=1 按行找 
tf.argmax()返回最大数值的下标 
通常和tf.equal()一起使用,计算模型准确度

correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

栗子

>>> import tensorflow as tf
>>> a = tf.constant([1.,2.,3.,0.,9.,])
>>> b = tf.constant([[1,2,3],[3,2,1],[4,5,6],[6,5,4]])
>>> with tf.Session() as sess:
...     sess.run(tf.argmax(a, 0))
Output:
4
>>> with tf.Session() as sess:
...     sess.run(tf.argmax(b, 0))
Output:
array([3, 2, 2])
>>> with tf.Session() as sess:
...     sess.run(tf.argmax(b, 1))
Output:
array([2, 0, 2, 0])

Ref: 
API文档

--------------------- 本文来自 Jaichg 的CSDN 博客 ,全文地址请点击:https://blog.csdn.net/Jiaach/article/details/78874704?utm_source=copy 

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值