tf.cast()函数讲解
参数:tf.cast(tensor,dtype=?)
作用:将x的数据格式转化为dtype的数据类型。
示例代码:
a=[1,2,0,0,1]
s=tf.math.equal(a,0)
print(s)
d=tf.cast(s,dtype=tf.float32)
print(d)
输出:
tf.Tensor([False False True True False], shape=(5,), dtype=bool)
tf.Tensor([0. 0. 1. 1. 0.], shape=(5,), dtype=float32)