tf.transpose()函数介绍和示例
tf.transpose(X, perm=None, name=‘transpose’, conjugate=False)
释义:交换维度
- X:需要变换的张量
- perm:新的维度序列
- name:(可选)操作名称
- conjugate:(可选),设置为True,等价于 tf.conj(tf.transpose()),为维度倒序交换,且共轭
示例1:二维
import tensorflow as tf
import numpy as np
A = np.array([[1, 2, 3], [4, 5, 6]])
X = tf.transpose(A, [1, 0]) # 相当于转置
with tf.Session() as sess:
print('original:\n',A)
print('tranpose:\n',sess.run(X))
original:
[[1 2 3]
[4 5 6]]
tranpose:
[[1 4]
[2 5]
[3 6]]
示例2:三维
import tensorflow as tf
import numpy as np
A = np.arange(12).reshape([2,3,2])
X = tf.transpose(A,[0,2,1]) # 交换维度 1 和维度 2
Y = tf.transpose(A,[1,0,2]) # 交换维度 0 和维度 1
with tf.Session() as sess:
print('original:\n', A)
print('A.shape:', A.shape)
print('='*30)
print('transpose [0,2,1]:\n', sess.run(X))
print('X.shape:', X.shape)
print('='*30)
print('transpose [1,0,2]:\n', sess.run(Y))
print('Y.shape:', Y.shape)
original:
[[[ 0 1]
[ 2 3]
[ 4 5]]
[[ 6 7]
[ 8 9]
[10 11]]]
A.shape: (2, 3, 2)
==============================
transpose [0,2,1]:
[[[ 0 2 4]
[ 1 3 5]]
[[ 6 8 10]
[ 7 9 11]]]
X.shape: (2, 2, 3)
==============================
transpose [1,0,2]:
[[[ 0 1]
[ 6 7]]
[[ 2 3]
[ 8 9]]
[[ 4 5]
[10 11]]]
Y.shape: (3, 2, 2)
示例3:conjugate
-
实数情况下,维度倒序交换
import tensorflow as tf import numpy as np A = tf.constant(np.arange(120).reshape([2,3,4,5])) X = tf.transpose(A, conjugate=True) # conjugate=True,维度倒序交换 X1 = tf.conj(tf.transpose(A)) # 同理,tf.conj(tf.transpose()) X2 = tf.transpose(A, [3,2,1,0]) # 同理,维度 [3,2,1,0] with tf.Session() as sess: print('A.shape:', A.shape) print('X.shape:', X.shape) print('X1.shape:', X1.shape) print('X2.shape:', X2.shape)
A.shape: (2, 3, 4, 5) X.shape: (5, 4, 3, 2) X1.shape: (5, 4, 3, 2) X2.shape: (5, 4, 3, 2)
-
复数情况下,复数共轭,且维度倒序交换
import tensorflow as tf A = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j, 6 + 6j]]) X = tf.transpose(A, conjugate=True) # conjugate=True,复数共轭,维度倒序 X_ = tf.conj(tf.transpose(A)) # 同理,tf.conj(tf.transpose()) with tf.Session() as sess: print('A.shape:', A.shape) print(sess.run(A)) print('='*30) print('X.shape:', X.shape) print(sess.run(X)) print('='*30) print('X_.shape:', X_.shape) print(sess.run(X_))
A.shape: (2, 3) [[1.+1.j 2.+2.j 3.+3.j] [4.+4.j 5.+5.j 6.+6.j]] ============================== X.shape: (3, 2) [[1.-1.j 4.-4.j] [2.-2.j 5.-5.j] [3.-3.j 6.-6.j]] ============================== X_.shape: (3, 2) [[1.-1.j 4.-4.j] [2.-2.j 5.-5.j] [3.-3.j 6.-6.j]]