网上有很多关于这个的解释,感觉解释的不清楚
解释:删除维度为1的张张量,举例说明,如果一个张量,shape为(100,1,1,100,3) 经过这个函数处理之后就变成了(100,100,3)
import tensorflow as tf
import numpy as np
data = np.ones((100,1,1,100,3))
x = tf.placeholder(tf.float32,[100,1,1,100,3])
y = tf.squeeze(x)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
result = sess.run(y,feed_dict={x:data})
print(result.shape)
# (100, 100, 3)
import tensorflow as tf
import numpy as np
data = np.ones((100,1,1,100,3))
x = tf.placeholder(tf.float32,[100,1,1,100,3])
y = tf.squeeze(x,[1])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
result = sess.run(y,feed_dict={x:data})
print(result.shape)
# (100,1,100,3)
解释一下:
tf.squeeze(input, axis=None, name=None, squeeze_dims=None)
如果加了axis 则需要指明是哪一个维度,设置哪一个,则删除哪一个,如果不设置,则默认全部删除1维度的。