tf.cast(x, dtype, name=None)
此函数是类型转换函数
参数
- x:输入
- dtype:转换目标类型
- name:名称
tf.cast()一般用来转换数据类型,下面一个例子将数据类型转换成bool类型
import tensorflow as tf
a = tf.Variable([1,0,0,1,0])
b = tf.cast(a,dtype = bool)
sess = tf.Session()
sess.run(tf.initialize_all_variables())
print sess.run(b)
顺便熟悉一下tf两个session的用法
import tensorflow as tf
a = tf.Variable([1,0,0,1,0])
b = tf.cast(a,dtype = bool)
with tf.Session() as sess:
sess = tf.Session()
tf.initialize_all_variables().run(session = sess)
b = sess.run(b)
print b
再来一个小例子,将浮点型小数转换成整数的例子
import tensorflow as tf
a = tf.Variable([1,0,0,1,0])
b = tf.cast(a,dtype = bool)
a = tf.Variable([2.6,4.8])
b = tf.cast(a,dtype = tf.int8)
sess = tf.Session()
sess.run(tf.initialize_all_variables())
print sess.run(b)