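# TensorFlow 中的 tf.unstack():
# 解释: 该函数沿指定维度 axis 将一个秩为 R (R > 0) 的张量拆分成一个列表,
# 列表长度等于该维度的大小, 列表中每个元素都是秩为 R-1 的张量(被拆分的维度会消失)。
# 以下为关键参数解释:
# value: 需要拆分的张量(即一个多维数组, 秩 R > 0, 不限于二维, 下面的示例就是四维);
# axis: 指明沿张量的哪个维度进行拆分, 默认为 0; 负数表示从末尾倒数, 合法范围为 [-R, R)。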
def unstack(value, num=None, axis=0, name="unstack"):
    """Unpack a rank-R array into a list of rank-(R-1) arrays along ``axis``.

    NumPy reference version of ``tf.unstack`` with the same signature, so the
    semantics described here can actually be run and checked.

    For example, given an array of shape ``(A, B, C, D)``:

    * if ``axis == 0`` the i'th element of the output is the slice
      ``value[i, :, :, :]`` and every element has shape ``(B, C, D)``;
    * if ``axis == 1`` the i'th element is the slice ``value[:, i, :, :]``
      and every element has shape ``(A, C, D)``; and so on.

    Unlike ``split``, the dimension unpacked along is removed from each result.

    Args:
        value: Array-like of rank ``R > 0`` to be unstacked.
        num: Optional ``int``, the expected length of dimension ``axis``.
            Inferred from ``value`` when ``None`` (the default).
        axis: ``int``, the axis to unstack along. Defaults to the first
            dimension. Negative values wrap around, so the valid range is
            ``[-R, R)``.
        name: Accepted only for signature compatibility with ``tf.unstack``;
            this NumPy version does not use it.

    Returns:
        A list of ``value.shape[axis]`` NumPy arrays, each of rank ``R - 1``.
        When ``value`` is already an ndarray these are views into it, so copy
        them before modifying in place.

    Raises:
        ValueError: If ``value`` is a scalar, ``axis`` is outside ``[-R, R)``,
            or ``num`` does not match the length of dimension ``axis``.
    """
    import numpy as np

    arr = np.asarray(value)
    rank = arr.ndim
    if rank == 0:
        raise ValueError("Input must be at least rank 1.")
    if not -rank <= axis < rank:
        raise ValueError("axis = {} not in [{}, {})".format(axis, -rank, rank))
    length = arr.shape[axis]
    if num is not None and num != length:
        raise ValueError(
            "num = {} does not match the size {} of dimension {}".format(
                num, length, axis))
    # Unstacking == moving `axis` to the front and iterating over it; this is
    # why stacking the result back yields a transpose of `value`.
    return list(np.moveaxis(arr, axis, 0))
# 详解 -- 维度理解(以形状为 (2, 3, 4, 5) 的数组为例):
import tensorflow as tf
import numpy as np
# `tf.Session` and `tf.global_variables_initializer` were removed from the
# TF 2.x top-level namespace, and `Session.run` only accepts graph-mode
# tensors. Going through `tf.compat.v1` with eager execution disabled keeps
# this TF1-style demo running on TF 2.x as well as on TF 1.14+ (where graph
# mode is already the default, so the call is harmless).
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.Session()
# The demo creates no tf.Variable, so there is nothing to initialize.

# 1..120 laid out row-major as a (2, 3, 4, 5) array, so every value encodes
# its own index: ust_data[i, j, k, l] == 1 + 60*i + 20*j + 5*k + l.
ust_data = np.arange(1, 121).reshape([2, 3, 4, 5])
print('ust_data: \n', ust_data)
print('ust_data.shape: \n', ust_data.shape)
# ust_data:
# [[[[ 1 2 3 4 5]
# [ 6 7 8 9 10]
# [ 11 12 13 14 15]
# [ 16 17 18 19 20]]
#
# [[ 21 22 23 24 25]
# [ 26 27 28 29 30]
# [ 31 32 33 34 35]
# [ 36 37 38 39 40]]
#
# [[ 41 42 43 44 45]
# [ 46 47 48 49 50]
# [ 51 52 53 54 55]
# [ 56 57 58 59 60]]]
#
#
# [[[ 61 62 63 64 65]
# [ 66 67 68 69 70]
# [ 71 72 73 74 75]
# [ 76 77 78 79 80]]
#
# [[ 81 82 83 84 85]
# [ 86 87 88 89 90]
# [ 91 92 93 94 95]
# [ 96 97 98 99 100]]
#
# [[101 102 103 104 105]
# [106 107 108 109 110]
# [111 112 113 114 115]
# [116 117 118 119 120]]]]
# ust_data.shape:
# (2, 3, 4, 5)
print('======================================')
# axis=0: element i is ust_data[i, :, :, :] -> a list of 2 arrays, each (3, 4, 5).
ust_0 = tf.unstack(ust_data, axis=0)
ust_0 = sess.run(ust_0)
print('ust_0: ', ust_0)
# Print the list length like the axis=1/2/3 sections do.
print('ust_0.length: ', len(ust_0))
print('ust_0.shape: ', len(ust_0), ' * ', ust_0[0].shape)  # 2 * (3, 4, 5) ==> (2, 3, 4, 5)
# ust_0: [array([[[ 1, 2, 3, 4, 5],
# [ 6, 7, 8, 9, 10],
# [11, 12, 13, 14, 15],
# [16, 17, 18, 19, 20]],
#
# [[21, 22, 23, 24, 25],
# [26, 27, 28, 29, 30],
# [31, 32, 33, 34, 35],
# [36, 37, 38, 39, 40]],
#
# [[41, 42, 43, 44, 45],
# [46, 47, 48, 49, 50],
# [51, 52, 53, 54, 55],
# [56, 57, 58, 59, 60]]]), array([[[ 61, 62, 63, 64, 65],
# [ 66, 67, 68, 69, 70],
# [ 71, 72, 73, 74, 75],
# [ 76, 77, 78, 79, 80]],
#
# [[ 81, 82, 83, 84, 85],
# [ 86, 87, 88, 89, 90],
# [ 91, 92, 93, 94, 95],
# [ 96, 97, 98, 99, 100]],
#
# [[101, 102, 103, 104, 105],
# [106, 107, 108, 109, 110],
# [111, 112, 113, 114, 115],
# [116, 117, 118, 119, 120]]])]
# ust_0.length: 2
# ust_0.shape: 2 * (3, 4, 5)
print('======================================')
# axis=1: element i is ust_data[:, i, :, :] -> a list of 3 arrays, each (2, 4, 5).
# Build the unstack op and evaluate it in one step.
ust_1 = sess.run(tf.unstack(ust_data, axis=1))
print('ust_1: ', ust_1)
print('ust_1.length: ', len(ust_1))
print('ust_1.shape: ', len(ust_1), ' * ', ust_1[0].shape)  # 3 * (2, 4, 5) ==> (3, 2, 4, 5)
# Expected output:
# ust_1: [array([[[ 1, 2, 3, 4, 5],
# [ 6, 7, 8, 9, 10],
# [11, 12, 13, 14, 15],
# [16, 17, 18, 19, 20]],
#
# [[61, 62, 63, 64, 65],
# [66, 67, 68, 69, 70],
# [71, 72, 73, 74, 75],
# [76, 77, 78, 79, 80]]]), array([[[ 21, 22, 23, 24, 25],
# [ 26, 27, 28, 29, 30],
# [ 31, 32, 33, 34, 35],
# [ 36, 37, 38, 39, 40]],
#
# [[ 81, 82, 83, 84, 85],
# [ 86, 87, 88, 89, 90],
# [ 91, 92, 93, 94, 95],
# [ 96, 97, 98, 99, 100]]]), array([[[ 41, 42, 43, 44, 45],
# [ 46, 47, 48, 49, 50],
# [ 51, 52, 53, 54, 55],
# [ 56, 57, 58, 59, 60]],
#
# [[101, 102, 103, 104, 105],
# [106, 107, 108, 109, 110],
# [111, 112, 113, 114, 115],
# [116, 117, 118, 119, 120]]])]
# ust_1.length: 3
# ust_1.shape: 3 * (2, 4, 5)
print('======================================')
# axis=2: element i is ust_data[:, :, i, :] -> a list of 4 arrays, each (2, 3, 5).
# Build the unstack op and evaluate it in one step.
ust_2 = sess.run(tf.unstack(ust_data, axis=2))
print('ust_2: ', ust_2)
print('ust_2.length: ', len(ust_2))
print('ust_2.shape: ', len(ust_2), ' * ', ust_2[0].shape)  # 4 * (2, 3, 5) ==> (4, 2, 3, 5)
# Expected output:
# ust_2: [array([[[ 1, 2, 3, 4, 5],
# [ 21, 22, 23, 24, 25],
# [ 41, 42, 43, 44, 45]],
#
# [[ 61, 62, 63, 64, 65],
# [ 81, 82, 83, 84, 85],
# [101, 102, 103, 104, 105]]]), array([[[ 6, 7, 8, 9, 10],
# [ 26, 27, 28, 29, 30],
# [ 46, 47, 48, 49, 50]],
#
# [[ 66, 67, 68, 69, 70],
# [ 86, 87, 88, 89, 90],
# [106, 107, 108, 109, 110]]]), array([[[ 11, 12, 13, 14, 15],
# [ 31, 32, 33, 34, 35],
# [ 51, 52, 53, 54, 55]],
#
# [[ 71, 72, 73, 74, 75],
# [ 91, 92, 93, 94, 95],
# [111, 112, 113, 114, 115]]]), array([[[ 16, 17, 18, 19, 20],
# [ 36, 37, 38, 39, 40],
# [ 56, 57, 58, 59, 60]],
#
# [[ 76, 77, 78, 79, 80],
# [ 96, 97, 98, 99, 100],
# [116, 117, 118, 119, 120]]])]
# ust_2.length: 4
# ust_2.shape: 4 * (2, 3, 5)
print('======================================')
# axis=3 (the last one): element i is ust_data[:, :, :, i] -> a list of
# 5 arrays, each (2, 3, 4).
# Build the unstack op and evaluate it in one step.
ust_3 = sess.run(tf.unstack(ust_data, axis=3))
print('ust_3: ', ust_3)
print('ust_3.length: ', len(ust_3))
print('ust_3.shape: ', len(ust_3), ' * ', ust_3[0].shape)  # 5 * (2, 3, 4) ==> (5, 2, 3, 4)
# Expected output:
# ust_3: [array([[[ 1, 6, 11, 16],
# [ 21, 26, 31, 36],
# [ 41, 46, 51, 56]],
#
# [[ 61, 66, 71, 76],
# [ 81, 86, 91, 96],
# [101, 106, 111, 116]]]), array([[[ 2, 7, 12, 17],
# [ 22, 27, 32, 37],
# [ 42, 47, 52, 57]],
#
# [[ 62, 67, 72, 77],
# [ 82, 87, 92, 97],
# [102, 107, 112, 117]]]), array([[[ 3, 8, 13, 18],
# [ 23, 28, 33, 38],
# [ 43, 48, 53, 58]],
#
# [[ 63, 68, 73, 78],
# [ 83, 88, 93, 98],
# [103, 108, 113, 118]]]), array([[[ 4, 9, 14, 19],
# [ 24, 29, 34, 39],
# [ 44, 49, 54, 59]],
#
# [[ 64, 69, 74, 79],
# [ 84, 89, 94, 99],
# [104, 109, 114, 119]]]), array([[[ 5, 10, 15, 20],
# [ 25, 30, 35, 40],
# [ 45, 50, 55, 60]],
#
# [[ 65, 70, 75, 80],
# [ 85, 90, 95, 100],
# [105, 110, 115, 120]]])]
# ust_3.length: 5
# ust_3.shape: 5 * (2, 3, 4)
# 理解:
# tf.unstack 相当于先把 axis 维度移到最前面, 再沿第 0 维逐个切片;
# 因此把它的结果重新 tf.stack 起来, 就等价于一次 tf.transpose。
# 总结:
# 设 stacks 的形状为 (维1, 维2, 维3, 维4), 则 tf.stack(tf.unstack(stacks, axis=k)) 满足:
# 当axis=0时, 相当于tf.transpose(stacks, [0, 1, 2, 3])
# 当axis=1时, 相当于tf.transpose(stacks, [1, 0, 2, 3])
# 当axis=2时, 相当于tf.transpose(stacks, [2, 0, 1, 3])
# 当axis=3时, 相当于tf.transpose(stacks, [3, 0, 1, 2])
# (注意: tf.unstack 本身返回的是一个张量列表, 而不是单个张量。)