- tf.cast()
tf.cast()的作用是将一个张量的类型改变为另外一个类型,如第11行,将浮点型转化为整数型
def cast(x, dtype, name=None):
"""Casts a tensor to a new type.
The operation casts `x` (in case of `Tensor`) or `x.values`
(in case of `SparseTensor`) to `dtype`.
For example:
x = tf.constant([1.8, 2.2], dtype=tf.float32)
tf.cast(x, tf.int32) # [1, 2], dtype=tf.int32
- tf.argmax()
def argmax(input,
axis=None,
name=None,
dimension=None,
output_type=dtypes.int64):
tf.argmax()的作用是选择某个维度中,最大值所对应的下标(索引);当axis=1时,表示在每一行选取最大值对应的下标。
应该axis为[-1,1),不能取1,-1为行,0为列。
例如:
a=tf.constant([1,2,3,4,5,0],dtype=tf.int32)
tf.argmax(a,axis=-1) # 4
作者:亡城
来源:CSDN
原文:https://blog.csdn.net/The_lastest/article/details/81050778
版权声明:本文为博主原创文章,转载请附上博文链接!