TensorFlow 移除所有尺度为1的维度 tf.squeeze 的基本用法及实例代码

一、环境

TensorFlow API r1.12

CUDA 9.2 V9.2.148

cudnn64_7.dll

Python 3.6.3

Windows 10

 

二、官方说明

从张量的形状中移除所有尺寸为1的维数。(弃用参数)

https://tensorflow.google.cn/api_docs/python/tf/squeeze

tf.squeeze(
    input,
    axis=None,
    name=None,
    squeeze_dims=None
)

参数:

input:要缩减维度的张量

axis:可选整型列表,默认为 [ ],如果指定了给参数,值域列表中指定的维度会被移除。维度所以从 0 开始,范围是 [- rank(input), rank(input)]。不能移除尺度不为 1 的维度,否则会报错!

name:可选参数,设置操作的名称

squeeze_dims:被移除的关键字参数,通过 axis 替代

 

返回:

包含输入 input 中的数据,但移除了所有尺度为 1 的维度的张量,和输入 input 的数据类型相同

 

三、实例

(1)尺度缩减的错误方式

>>> raw_tensor = tf.constant(value=[[[1,2,3],[4,5,6]]])
>>> raw_tensor
<tf.Tensor 'Const_1:0' shape=(1, 2, 3) dtype=int32>
>>> squeezed_tensor = tf.squeeze(input=raw_tensor, axis=[1])
Traceback (most recent call last):
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\framework\ops.py", line 1628, in _create_c_op
    c_op = c_api.TF_FinishOperation(op_desc)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Can not squeeze dim[1], expected a dimension of 1, got 2 for 'Squeeze_4' (op: 'Squeeze') with input shapes: [1,2,3].

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\util\deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\ops\array_ops.py", line 2573, in squeeze
    return gen_array_ops.squeeze(input, axis, name)
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\ops\gen_array_ops.py", line 10108, in squeeze
    "Squeeze", input=input, squeeze_dims=axis, name=name)
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\util\deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\framework\ops.py", line 3274, in create_op
    op_def=op_def)
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\framework\ops.py", line 1792, in __init__
    control_input_ops)
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\framework\ops.py", line 1631, in _create_c_op
    raise ValueError(str(e))
ValueError: Can not squeeze dim[1], expected a dimension of 1, got 2 for 'Squeeze_4' (op: 'Squeeze') with input shapes: [1,2,3].

ValueError: Can not squeeze dim[1], expected a dimension of 1, got 2 for 'Squeeze_4' (op: 'Squeeze') with input shapes: [1,2,3].

即不能移除尺度不为 1 的维度

 

(2)尺度缩减的正确方式

>>> import tensorflow as tf


# 向量 (1, 3) --> 标量(3,)
# 移除尺度为 1 的第一个维度
>>> <tf.Tensor 'Const:0' shape=(1, 3) dtype=int32>
>>> raw_tensor = tf.constant(value=[[1,2,3]])
>>> squeezed_tensor = tf.squeeze(input=raw_tensor)
>>> squeezed_tensor
<tf.Tensor 'Squeeze:0' shape=(3,) dtype=int32>


# 矩阵 (1, 3, 3) --> 标量(3, 3)
# 移除尺度为 1 的第一个维度
>>> raw_tensor = tf.constant(value=[[[1,2,3],[4,5,6],[7,8,9]]])
>>> raw_tensor
<tf.Tensor 'Const:0' shape=(1, 3, 3) dtype=int32>
>>> squeezed_tensor = tf.squeeze(input=raw_tensor)
>>> squeezed_tensor
<tf.Tensor 'Squeeze:0' shape=(3, 3) dtype=int32>


# 矩阵 (1, 1, 3) --> 标量(3,)
# 移除尺度为 1 的前两个维度
>>> raw_tensor = tf.constant(value=[[[1,2,3]]])
>>> raw_tensor
<tf.Tensor 'Const:0' shape=(1, 1, 3) dtype=int32>
>>> squeezed_tensor = tf.squeeze(input=raw_tensor)
>>> squeezed_tensor
<tf.Tensor 'Squeeze:0' shape=(3,) dtype=int32>


# 通过参数 axis 指定的一个要的尺度为 1 的维度
>>> raw_tensor = tf.constant(value=[[[1,2,3]]])
>>> raw_tensor
<tf.Tensor 'Const:0' shape=(1, 1, 3) dtype=int32>
>>> squeezed_tensor = tf.squeeze(input=raw_tensor, axis=[1])
>>> squeezed_tensor
<tf.Tensor 'Squeeze_0:0' shape=(1, 3) dtype=int32>


# 通过参数 axis 指定的多个要的尺度为 1 的维度
>>> raw_tensor = tf.constant(value=[[[1,2,3]]])
>>> raw_tensor
<tf.Tensor 'Const:0' shape=(1, 1, 3) dtype=int32>
>>> squeezed_tensor = tf.squeeze(input=raw_tensor, axis=[0,1])
>>> squeezed_tensor
<tf.Tensor 'Squeeze_3:0' shape=(3,) dtype=int32>
  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

csdn-WJW

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值