官方API文档:
transpose(
a, # 输入张量
perm=None, # 转置规则,后面详细介绍
name='transpose'
)
Args:
a: A Tensor.
perm: A permutation of the dimensions of a.
name: A name for the operation (optional).
Returns:
A transposed Tensor.
看一下官方给出的例子:
原本张量x维度:2*3,转置后为3*2,使用perm时候,参数列表表示张量从外向里的一维,0表示最外层,1表示从外向里的第1层,由于x这个张量只有两层,设置perm=[1, 0]就相当于将原本的2*3转为3*2:
# 'x' is [[1 2 3]
# [4 5 6]]
tf.transpose(x) ==> [[1 4]
[2 5]
[3 6]]
# Equivalently
tf.transpose(x, perm=[1, 0]) ==> [[1 4]
[2 5]
[3 6]]
如果张量维数大于2,perm这个参数作用就更明显了:输入张量x的维度是2*2*3,perm=[0, 2, 1],相当于转化为2*3*2,最外层那维不变,内两维转置:
# 'perm' is more useful for n-dimensional tensors, for n > 2
# 'x' is [[[1 2 3]
# [4 5 6]]
# [[7 8 9]
# [10 11 12]]]
# Take the transpose of the matrices in dimension-0
tf.transpose(x, perm=[0, 2, 1]) ==> [[[1 4]
[2 5]
[3 6]]
[[7 10]
[8 11]
[9 12]]]
再一般化些,如果输入x的各层维数都不同:
x = [
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]
],
[
[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]
]
]
# 输入x的维数为:2*3*4
output1=tf.transpose(x, perm=[1,0,2]) # 转置后,维数为:3*2*4
output1=
[
[
[ 1 2 3 4]
[13 14 15 16]
]
[
[ 5 6 7 8]
[17 18 19 20]
]
[
[ 9 10 11 12]
[21 22 23 24]
]
]
output2=tf.transpose(x, perm=[0,2,1]) # 转置后,维数为:2*4*3
output2=
[
[
[ 1 5 9]
[ 2 6 10]
[ 3 7 11]
[ 4 8 12]
]
[
[13 17 21]
[14 18 22]
[15 19 23]
[16 20 24]
]
]