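Signature:
def cast(x, dtype, name=None)
Description:
Converts a tensor to a different data type.
x: the tensor whose data type is to be changed
dtype: the target data type
name: optional; a name for the operation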
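To make the effect of a data type change concrete, here is a minimal sketch that uses only Python built-ins, not TensorFlow. It mirrors the conversions that the `tf.cast` docstring describes: float to integer drops the fractional part, complex to real keeps only the real part, and real to complex gets a zero imaginary part. The helper name `cast_scalar` is hypothetical, and truncation toward zero for negative values is an assumption, since the docstring only shows positive inputs.

```python
# Plain-Python analogy for the value conversions described for tf.cast.
# cast_scalar is a hypothetical helper, not part of TensorFlow.

def cast_scalar(value, target):
    """Convert one Python number to `target` (int, float, or complex)."""
    if isinstance(value, complex) and target is not complex:
        # Complex -> real: only the real part is kept.
        value = value.real
    # int() truncates toward zero; complex() adds a zero imaginary part.
    return target(value)


print([cast_scalar(v, int) for v in [1.8, 2.2]])  # [1, 2]
print(cast_scalar(3.0 + 4.0j, float))             # 3.0
print(cast_scalar(2.5, complex))                  # (2.5+0j)
```

With real tensors the same conversions are applied element by element, and the shape is unchanged.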
Location of cast in the TensorFlow source: tensorflow/python/ops/math_ops.py
The full definition is as follows:
def cast(x, dtype, name=None):
  """Casts a tensor to a new type.

  The operation casts `x` (in case of `Tensor`) or `x.values`
  (in case of `SparseTensor` or `IndexedSlices`) to `dtype`.

  For example:

  ```python
  x = tf.constant([1.8, 2.2], dtype=tf.float32)
  tf.cast(x, tf.int32)  # [1, 2], dtype=tf.int32
  ```

  The operation supports data types (for `x` and `dtype`) of
  `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`,
  `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`.
  In case of casting from complex types (`complex64`, `complex128`) to real
  types, only the real part of `x` is returned. In case of casting from real
  types to complex types (`complex64`, `complex128`), the imaginary part of the
  returned value is set to `0`. The handling of complex types here matches the
  behavior of numpy.

  Args:
    x: A `Tensor` or `SparseTensor` or `IndexedSlices` of numeric type. It could
      be `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`,
      `int64`, `float16`, `float32`, `float64`, `complex64`, `complex128`,
      `bfloat16`.
    dtype: The destination type. The list of supported dtypes is the same as
      `x`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` and
      same type as `dtype`.

  Raises:
    TypeError: If `x` cannot be cast to the `dtype`.
  """
  base_type = dtypes.as_dtype(dtype).base_dtype
  if isinstance(x,
                (ops.Tensor, _resource_variable_type)) and base_type == x.dtype:
    return x
  with ops.name_scope(name, "Cast", [x]) as name:
    if isinstance(x, sparse_tensor.SparseTensor):
      values_cast = cast(x.values, base_type, name=name)
      x = sparse_tensor.SparseTensor(x.indices, values_cast, x.dense_shape)
    elif isinstance(x, ops.IndexedSlices):
      values_cast = cast(x.values, base_type, name=name)
      x = ops.IndexedSlices(values_cast, x.indices, x.dense_shape)
    else:
      # TODO(josh11b): If x is not already a Tensor, we could return
      # ops.convert_to_tensor(x, dtype=dtype, ...)  here, but that
      # allows some conversions that cast() can't do, e.g. casting numbers to
      # strings.
      x = ops.convert_to_tensor(x, name="x")
      if x.dtype.base_dtype != base_type:
        x = gen_math_ops.cast(x, base_type, name=name)
    if x.dtype.is_complex and base_type.is_floating:
      logging.warn("Casting complex to real discards imaginary part.")
    return x
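The control flow of the definition above can be hard to follow at a glance, so here is a rough toy model of its three branches in plain Python: return early when the dtype already matches, cast only the stored `values` of a sparse container, and otherwise convert the input to a tensor and cast it. `ToyTensor`, `ToySparse`, `toy_cast`, and the string dtype tags are all hypothetical stand-ins chosen for illustration; none of them are TensorFlow APIs, and the sketch ignores name scopes, `IndexedSlices`, and the complex-number warning.

```python
# Toy model of the dispatch logic in cast(). All names here are hypothetical
# stand-ins for illustration, not TensorFlow types or functions.
from dataclasses import dataclass
from typing import List


@dataclass
class ToyTensor:
    values: List[float]
    dtype: str  # "int32" or "float32" in this sketch


@dataclass
class ToySparse:
    indices: List[int]
    values: ToyTensor
    dense_shape: List[int]


_CONVERTERS = {"int32": int, "float32": float}


def toy_cast(x, dtype):
    # Fast path: a dense tensor that already has the target dtype is
    # returned unchanged (the very same object).
    if isinstance(x, ToyTensor) and x.dtype == dtype:
        return x
    # Sparse container: cast only the stored values, keep indices and shape.
    if isinstance(x, ToySparse):
        return ToySparse(x.indices, toy_cast(x.values, dtype), x.dense_shape)
    # Anything else (here: a plain list) is first converted to a tensor.
    if not isinstance(x, ToyTensor):
        x = ToyTensor([float(v) for v in x], "float32")
    # Cast only if the converted tensor does not already match.
    if x.dtype == dtype:
        return x
    return ToyTensor([_CONVERTERS[dtype](v) for v in x.values], dtype)


t = ToyTensor([1.8, 2.2], "float32")
s = ToySparse([0, 3], t, [5])
print(toy_cast(t, "float32") is t)        # True
print(toy_cast(t, "int32"))               # ToyTensor(values=[1, 2], dtype='int32')
print(toy_cast(s, "int32").values.dtype)  # int32
```

The early return is the reason casting a tensor to its own dtype costs nothing: no new operation is created and the input object itself is handed back.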
Example:
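import tensorflow as tf


def tf_cast_1():
    # Uses the TensorFlow 1.x graph-mode API (tf.Session).
    t1 = tf.Variable([1, 2, 3, 4, 5, 6])  # integer variable (int32 by default)
    t2 = tf.cast(t1, dtype=tf.float32)    # same values, cast to float32
    with tf.Session() as sess:
        # Variables must be initialized before they can be read.
        sess.run(tf.global_variables_initializer())
        print('t1', t1, sess.run(t1))
        print('t2', t2, sess.run(t2))


tf_cast_1()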
The output is as follows: