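You may not need this often, but it is awkward to need it and not know how.
To add a dimension, use tf.expand_dims(); to remove one, use tf.squeeze(). Below is the source code of both functions; a quick read through the docstrings is enough to understand them.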
def expand_dims(input, axis=None, name=None, dim=None):
  """Inserts a dimension of 1 into a tensor's shape.
  Given a tensor `input`, this operation inserts a dimension of 1 at the
  dimension index `axis` of `input`'s shape. The dimension index `axis` starts
  at zero; if you specify a negative number for `axis` it is counted backward
  from the end.
  This operation is useful if you want to add a batch dimension to a single
  element. For example, if you have a single image of shape `[height, width,
  channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`,
  which will make the shape `[1, height, width, channels]`.
  Other examples:
  # (Blog author's note) Concrete examples: 0 inserts the new dimension at index 0, 1 inserts it at index 1
  ```python
  # 't' is a tensor of shape [2]
  tf.shape(tf.expand_dims(t, 0)) # [1, 2]
  tf.shape(tf.expand_dims(t, 1)) # [2, 1]
  tf.shape(tf.expand_dims(t, -1)) # [2, 1]
  # 't2' is a tensor of shape [2, 3, 5]
  tf.shape(tf.expand_dims(t2, 0)) # [1, 2, 3, 5]
  tf.shape(tf.expand_dims(t2, 2)) # [2, 3, 1, 5]
  tf.shape(tf.expand_dims(t2, 3)) # [2, 3, 5, 1]
  ```
  This operation requires that:
  `-1-input.dims() <= dim <= input.dims()`
  This operation is related to `squeeze()`, which removes dimensions of
  size 1.
  Args:
    input: A `Tensor`.
    axis: 0-D (scalar). Specifies the dimension index at which to
      expand the shape of `input`. Must be in the range
      `[-rank(input) - 1, rank(input)]`.
    name: The name of the output `Tensor`.
    dim: 0-D (scalar). Equivalent to `axis`, to be deprecated.
  Returns:
    A `Tensor` with the same data as `input`, but its shape has an additional
    dimension of size 1 added.
  Raises:
    ValueError: if both `dim` and `axis` are specified.
  """
  axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
  return gen_array_ops.expand_dims(input, axis, name)
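To make the `axis` rule concrete, here is a minimal sketch in plain Python that models only how the shape changes. `expand_shape` is a hypothetical helper, not a TensorFlow function, and it assumes the shape is passed as a list of ints. It follows the range and negative-index behavior described in the docstring; it does not touch any tensor data.

```python
def expand_shape(shape, axis):
    """Return the shape produced by inserting a size-1 dimension at `axis`.

    Hypothetical helper that models only the documented shape rule;
    it is not part of TensorFlow.
    """
    rank = len(shape)
    # Valid range from the docstring: [-rank - 1, rank].
    if not -rank - 1 <= axis <= rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    # A negative axis counts backward from the end of the output shape,
    # which has rank + 1 dimensions.
    if axis < 0:
        axis += rank + 1
    return list(shape[:axis]) + [1] + list(shape[axis:])


print(expand_shape([2], 0))        # [1, 2]
print(expand_shape([2], -1))       # [2, 1]
print(expand_shape([2, 3, 5], 2))  # [2, 3, 1, 5]
```

The printed shapes match the docstring examples, which is a quick way to check your reading of `axis` before calling the real op.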
The code for the other function:
def squeeze(input, axis=None, name=None, squeeze_dims=None):
  # pylint: disable=redefined-builtin
  """Removes dimensions of size 1 from the shape of a tensor.
  Given a tensor `input`, this operation returns a tensor of the same type with
  all dimensions of size 1 removed. If you don't want to remove all size 1
  dimensions, you can remove specific size 1 dimensions by specifying
  `axis`.
  For example:
  ```python
  # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
  tf.shape(tf.squeeze(t)) # [2, 3]
  ```
  Or, to remove specific size 1 dimensions:
  ```python
  # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
  tf.shape(tf.squeeze(t, [2, 4])) # [1, 2, 3, 1]
  ```
  Args:
    input: A `Tensor`. The `input` to squeeze.
    axis: An optional list of `ints`. Defaults to `[]`.
      If specified, only squeezes the dimensions listed. The dimension
      index starts at 0. It is an error to squeeze a dimension that is not 1.
      Must be in the range `[-rank(input), rank(input))`.
    name: A name for the operation (optional).
    squeeze_dims: Deprecated keyword argument that is now axis.
  Returns:
    A `Tensor`. Has the same type as `input`.
    Contains the same data as `input`, but has one or more dimensions of
    size 1 removed.
  Raises:
    ValueError: When both `squeeze_dims` and `axis` are specified.
  """
  axis = deprecation.deprecated_argument_lookup(
      "axis", axis, "squeeze_dims", squeeze_dims)
  if np.isscalar(axis):
    axis = [axis]
  return gen_array_ops.squeeze(input, axis, name)
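The squeeze rule can be modelled the same way. The sketch below is plain Python and works on shapes only. `squeeze_shape` is a hypothetical helper, not TensorFlow code. It assumes, as the docstring describes, that an empty or missing `axis` removes every size-1 dimension, that a bare int is treated as a one-element list (mirroring the `np.isscalar` step above), and that squeezing a dimension whose size is not 1 is an error.

```python
def squeeze_shape(shape, axis=None):
    """Return the shape left after removing size-1 dimensions.

    Hypothetical helper that models the documented shape rule;
    it is not part of TensorFlow.
    """
    rank = len(shape)
    # Treat a bare int like a one-element list of axes.
    if isinstance(axis, int):
        axis = [axis]
    # None or an empty list: drop every dimension of size 1.
    if not axis:
        return [d for d in shape if d != 1]
    drop = set()
    for a in axis:
        # Valid range from the docstring: [-rank, rank).
        if not -rank <= a < rank:
            raise ValueError(f"axis {a} out of range for rank {rank}")
        a %= rank  # map a negative index to its non-negative position
        if shape[a] != 1:
            raise ValueError(f"cannot squeeze dimension {a} of size {shape[a]}")
        drop.add(a)
    return [d for i, d in enumerate(shape) if i not in drop]


print(squeeze_shape([1, 2, 1, 3, 1, 1]))          # [2, 3]
print(squeeze_shape([1, 2, 1, 3, 1, 1], [2, 4]))  # [1, 2, 3, 1]
print(squeeze_shape([1, 2], 0))                   # [2]
```

Passing `axis` explicitly is usually the safer habit: with no `axis`, a batch dimension that happens to have size 1 is removed along with the one you meant to drop.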