1、Reshape
函数的作用是将tensor变换为参数shape的形式,其中shape为一个列表形式,特殊的一点是列表中可以存在-1
-1代表的含义是不用我们自己指定这一维的大小,函数会自动计算,但列表只能存在一个-1。(如果存在多个-1,就是一个存在多解的方程)
a = tf.random.normal([4, 28,28, 3])
a.shape, a.ndim
Out[74]: (TensorShape([4, 28, 28, 3]), 4)
tf.reshape(a, [4, 784, 3]).shape
Out[76]: TensorShape([4, 784, 3])
tf.reshape(a, [4, -1, 3]).shape
Out[77]: TensorShape([4, 784, 3])
tf.reshape(a, [4, 784*3]).shape
Out[78]: TensorShape([4, 2352])
tf.reshape(a, [4, -1]).shape
Out[79]: TensorShape([4, 2352])
# 恢复
a = tf.random.normal([4, 28,28, 3])
tf.reshape(tf.reshape(a, [4, -1]), [4, 28, 28, 3]).shape
Out[92]: TensorShape([4, 28, 28, 3])
tf.reshape(tf.reshape(a, [4, -1]), [4, 14, 56, 3]).shape
Out[93]: TensorShape([4, 14, 56, 3])
tf.reshape(tf.reshape(a, [4, -1]), [4, 1, 784, 3]).shape
Out[94]: TensorShape([4, 1, 784, 3])
2、转置 tf.transpose
将a进行转置,并且根据perm参数重新排列输出维度.
tf.transpose
(
a,
perm=None,
name='transpose',
conjugate=False
)
a - 表示的是需要变换的张量
perm - a的新的维度序列
name - 操作的名字,可选的
conjugate - 可选的,设置成True,那么就等于tf.conj(tf.transpose(input)),用的太少了
注:perm-控制转置的操作,perm = [0, 1, 3, 2]表示,把将要转置的第0和第1维度不变,将第2和第3维度进行转置。
a = tf.random.normal((4,3,2,1))
a.shape
Out[96]: TensorShape([4, 3, 2, 1])
tf.transpose(a).shape
Out[98]: TensorShape([1, 2, 3, 4])
tf.transpose(a, perm = [0, 1, 3, 2]).shape
Out[100]: TensorShape([4, 3, 1, 2])
注:pytorch的数据格式一般是[b, 3, h, w], 而tensorFlow的数据格式一般是[b, h, w, 3],所以数据传递时,需要做格式的转换。
a = tf.random.normal([4, 28, 28, 3]) # 一个tensor格式
tf.transpose(a, [0, 2, 1, 3]).shape
Out[103]: TensorShape([4, 28, 28, 3])
tf.transpose(a, [0, 3, 2, 1]).shape
Out[104]: TensorShape([4, 3, 28, 28])
tf.transpose(a, [0, 3, 1, 2]).shape # 转换成pytorch的数据格式
Out[105]: TensorShape([4, 3, 28, 28])
3、增减一个维度
- 增加一个维度tf.expand_dims
注:当axis为正数时,在tensor正向对应维度的左边增加一个维度
当axis为负数时,在tensor反向对应维度的右边增加一个维度
a = tf.random.normal([4, 35, 8])
tf.expand_dims(a, axis = 0).shape
Out[107]: TensorShape([1, 4, 35, 8])
tf.expand_dims(a, axis = 3).shape
Out[108]: TensorShape([4, 35, 8, 1])
tf.expand_dims(a, axis = -1).shape
Out[110]: TensorShape([4, 35, 8, 1])
tf.expand_dims(a, axis = -4).shape
Out[111]: TensorShape([1, 4, 35, 8])
- 减少维度 tf.squeeze
注,只能去掉shape = 1的维度,如[4, 35, 8, 1], 只能去掉最后为1的维度, axis可以指定值为1的维度
tf.squeeze(tf.zeros([1, 2, 1, 1, 3])).shape
Out[112]: TensorShape([2, 3])
tf.squeeze(a, axis = 0).shape
Out[114]: TensorShape([2, 1, 3])
tf.squeeze(a, axis = 2).shape
Out[115]: TensorShape([1, 2, 3])
tf.squeeze(a, axis = 1).shape # 报错
tensorflow.python.framework.errors_impl.InvalidArgumentError: Can not squeeze dim[1], expected a dimension of 1, got 2 [Op:Squeeze]
tf.squeeze(a, axis = -2).shape
Out[117]: TensorShape([1, 2, 3])
tf.squeeze(a, axis = -4).shape
Out[118]: TensorShape([2, 1, 3])