tf.cast()数据类型转换
代码 dtype=uint8转dtype=float32
print(x_)
x_float = tf.cast(x_ tf.float32)
print(x_float)
Tensor("preprocess/Reshape:0", shape=(?, 784), dtype=uint8)
Tensor("preprocess/Cast:0", shape=(?, 784), dtype=float32)
解析 x_float = tf.cast(x_, tf.float32)
x_ 张量,转换前数据
tf.float32 转换后数据类型
np.prod() 所有元素相乘
代码
np.prod([1, 2, 3, 4 ])
24