tf.cast:用于改变某个张量的数据类型
例如:
import tensorflow as tf;
import numpy as np;
A = tf.convert_to_tensor(np.array([[1,1,2,4], [3,4,8,5]]))
with tf.Session() as sess:
print A.dtype
b = tf.cast(A, tf.float32)
print b.dtype
输出:
<dtype: 'int64'>
<dtype: 'float32'>
开始的时候定义A没有给出类型,采用默认类型,整形。利用tf.cast函数就改为float类型