tf.cast: 映射
tf.cast(x, dtype, name=None)
x:输入
dtype:要映射的类型 如float32等
name: 自己给定,随意
例子:
# tensor `a` is [1.8, 2.2], dtype=tf.float
tf.cast(a, tf.int32) ==> [1, 2] # dtype=tf.int32
tf.argmax:输出最大的那个数值所在的索引号.输入向量,返回一个值,输入矩阵,返回一个向量,这个向量的每一个维度都是相对应矩阵行的最大值元素的索引号
tf.argmax(input, axis=None, name=None, dimension=None)
input:输入的tensor ,可以是float32,64 ,int32,64, 8,16 ,unit8,16,complex64,128 ,qint8,32 ,half 等类型
axis: 给定的tensor ,只能是 int32 ,int64, 0<=axis<=input .对于向量,axis=0
例子:
t1=[[1,2,3],[8,5,6],[7,8,1]]
t2=[2,5,7,0,3]
print(sess.run(tf.argmax(t1,1))) #axis=1
print(sess.run(tf.argmax(t2,0))) #axis=0
#output 对t1 [2,0,1] 即[1,2,3]中3最大,索引为2;[8,5,6]中8最大,索引为0; [7,8,1]中8最大,索引为1 ;最终就是[2,0,1]
对t2 2 即[2,5,7,0,3]中7最大,索引为2