虽然用了很久的这个函数,只记得住维度的交换,但经常忘记转换前后tensor的具体变化,再次记录下。
tf.transpose()作为数组的转置函数,原型如下:
def transpose(a, perm=None, name="transpose"):
"""Transposes `a`. Permutes the dimensions according to `perm`
a 表示是传入的数组
perm:控制转置的操作,以perm = [0,1,2] 3个维度的数组为例, 0--代表的是最外层的一维, 1--代表外向内数第二维, 2--代表最内层的一维,这种perm是默认的值.现在以如下输入数组来理解这个函数和参数perm
input_x = [
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]
],
[
[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]
]
]
显然,这里input_x 是一个 2x3x4的一个tensor,假设perm = [1,0,2],就是将最外2层转置,得到tensor应该是 3x2x4的一个张量,我们将input_x抽象化,不管第3维度,则可表示为:
[
[
A,
B,
C
],
[
D,
E,
F,
]
]
这样就变成了2x3的tensor,类似于2x3的数组
[
A D
B E
C F
]
再将A-F换成具体的值,最终得到的张量是
[
[
[ 1 2 3 4]
[13 14 15 16]
]
[
[ 5 6 7 8]
[17 18 19 20]
]
[
[ 9 10 11 12]
[21 22 23 24]
]
]
这就是perm前两列交换的结果了。
我们再看,如果 perm=[0,2,1]说明要交换内层里面的两个维度,从原来的2x3x4变成2x4x3的张量,就不抽象化了,结果就是:
[
[
[ 1 5 9]
[ 2 6 10]
[ 3 7 11]
[ 4 8 12]
]
[
[13 17 21]
[14 18 22]
[15 19 23]
[16 20 24]
]
]
具体代码为:
import tensorflow as tf
input_x = [
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]
],
[
[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]
]
]
result = tf.transpose(input_x, perm=[0, 2, 1])
with tf.Session() as sess:
print(sess.run(result))
注意与tf.reshape()的区别,tf.reshape()是先将所有的维度展平,再按新的维度划分,维度中的数据发生了变化。tf.transpose()转换后,维度数据并没有发生变化,只是从另外一个角度去观察tensor (比如一个三维的tensor (理解为长方体),transpose()可以想象成从不同角度去观察,实质上这个三维tensor并没有变化)。