定义
def squeeze(input, axis=None, name=None, squeeze_dims=None)
该函数返回一个张量,这个张量是将原始input中所有维度为1的那些维都删掉的结果。
axis
可以用来指定要删掉的为1的维度,此处要注意指定的维度必须确保其是1,否则会报错。
用法
删掉所有维度为1的:
# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
tf.shape(tf.squeeze(t)) # [2, 3]
删掉指定维度为1的:
# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
tf.shape(tf.squeeze(t, [2, 4])) # [1, 2, 3, 1]
tf.squeeze和tf.expand_dims互为逆操作。