代码及示例
#维度变换
import tensorflow as tf
import numpy as np
# reshape
a = tf.random.normal([4,28,28,3])
print(a.shape,a.ndim)
'''(4, 28, 28, 3) 4'''
print(tf.reshape(a, [4, 28 * 28, 3]).shape)
'''(4, 784, 3)'''
print(tf.reshape(a,[4,28*28*3]).shape)
'''(4, 2352)'''
print(tf.reshape(a,[4,-1]).shape)
'''(4, 2352)'''
# transpose ---- 重要参数:perm 默认是装置
a = tf.random.normal((4,3,2,1))
print(a.shape)
'''(4, 3, 2, 1)'''
print(tf.transpose(a).shape)
'''(1, 2, 3, 4)'''
print(tf.transpose(a,perm=[0,1,3,2]).shape)
'''(4, 3, 1, 2)'''
#实例
a = tf.constant([
[[1,2,3,4],[4,5,6,7]],
[[7,8,9,10],[10,11,12,13]],
[[1,2,3,4],[10,11,12,13]]
])
print(a)#shape=(3, 2, 4)
print(tf.transpose(a).shape)#(4, 2, 3)
print(tf.transpose(a,perm=[0,2,1]).shape)#(3, 4, 2)
# expand_dims 作用:用于增加一个维度
#重要参数:axis 维度
a = tf.random.normal([4,23,8])
print(a.shape)#(4, 23, 8)
print(tf.expand_dims(a, axis=0).shape)#(1, 4, 23, 8)
print(tf.expand_dims(a,axis=2).shape)#(4, 23, 1, 8)
print(tf.expand_dims(a,axis=-2).shape)#(4, 23, 1, 8)
# print(tf.expand_dims(a,axis=4).shape)#报错
#squeeze 作用:减少维度 !!!注意:只有一维的可以降维
a = tf.random.normal([4,2,1,3,8,1])
print(a.shape)#(4, 2, 1, 3, 8, 1)
print(tf.squeeze(a).shape)#(4, 2, 3, 8)
print(tf.squeeze(a,axis=2).shape)#(4, 2, 3, 8, 1)
总结
- reshape
改变维度 - transpose
重要参数:perm 默认是转置 - expand_dims
作用:用于增加一个维度 - squeeze
作用:减少维度
!!!注意:只有一维的可以降维
具体用法看代码