tf.transpose():用于转置的操作
def transpose_v2(a, perm=None, conjugate=False, name="transpose"):
- perm :控制转置的操作,以perm = [0,1,2] 3个维度的数组为例, 0–代表的是最外层的一维, 1–代表外向内数第二维, 2–代表最内层的一维,这种perm是默认的值.如果换成[1,0,2],就是把最外层的两维进行转置,比如原来是2乘3乘4,经过[1,0,2]的转置维度将会变成3乘2乘4
示例:
import tensorflow as tf input_x = [ [ [1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12] ], [ [13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24] ] ] print("tf.transpose(input_x, (1, 0, 2)):", tf.transpose(input_x, (1, 0, 2))) print("tf.transpose(input_x, (0, 2, 1)):", tf.transpose(input_x, (0, 2, 1)))
输出:
tf.transpose(input_x, (1, 0, 2)): tf.Tensor( [[[ 1 2 3 4] [13 14 15 16]] [[ 5 6 7 8] [17 18 19 20]] [[ 9 10 11 12] [21 22 23 24]]], shape=(3, 2, 4), dtype=int32) tf.transpose(input_x, (0, 2, 1)): tf.Tensor( [[[ 1 5 9] [ 2 6 10] [ 3 7 11] [ 4 8 12]] [[13 17 21] [14 18 22] [15 19 23] [16 20 24]]], shape=(2, 4, 3), dtype=int32)