tf.unstack() 详解 —》理解为主

 tensorflow中的tf.unstack():

解释:这是一个对矩阵进行分解的函数,以下为关键参数解释:

value:代表需要分解的矩阵变量(其实就是一个多维数组,一般为二维);

axis:指明对矩阵的哪个维度进行分解。

def unstack(value, num=None, axis=0, name="unstack"):
  '''For example, given a tensor of shape `(A, B, C, D)`;

  If `axis == 0` then the i'th tensor in `output` is the slice
    `value[i, :, :, :]` and each tensor in `output` will have shape `(B, C, D)`.
    (Note that the dimension unpacked along is gone, unlike `split`).

  If `axis == 1` then the i'th tensor in `output` is the slice
    `value[:, i, :, :]` and each tensor in `output` will have shape `(A, C, D)`.
  Etc.'''
    
    '''Args:
    value: A rank `R > 0` `Tensor` to be unstacked.
    num: An `int`. The length of the dimension `axis`. Automatically inferred
      if `None` (the default).
    axis: An `int`. The axis to unstack along. Defaults to the first
      dimension. Negative values wrap around, so the valid range is `[-R, R)`.
    name: A name for the operation (optional).'''

 详解 -- 维度理解:

import tensorflow as tf
import numpy as np

sess = tf.Session()
sess.run(tf.global_variables_initializer())

ust_data = np.arange(1, 121).reshape([2, 3, 4, 5])
print('ust_data: \n', ust_data)
print('ust_data.shape: \n', ust_data.shape)
# ust_data:
#  [[[[  1   2   3   4   5]
#    [  6   7   8   9  10]
#    [ 11  12  13  14  15]
#    [ 16  17  18  19  20]]
#
#   [[ 21  22  23  24  25]
#    [ 26  27  28  29  30]
#    [ 31  32  33  34  35]
#    [ 36  37  38  39  40]]
#
#   [[ 41  42  43  44  45]
#    [ 46  47  48  49  50]
#    [ 51  52  53  54  55]
#    [ 56  57  58  59  60]]]
#
#
#  [[[ 61  62  63  64  65]
#    [ 66  67  68  69  70]
#    [ 71  72  73  74  75]
#    [ 76  77  78  79  80]]
#
#   [[ 81  82  83  84  85]
#    [ 86  87  88  89  90]
#    [ 91  92  93  94  95]
#    [ 96  97  98  99 100]]
#
#   [[101 102 103 104 105]
#    [106 107 108 109 110]
#    [111 112 113 114 115]
#    [116 117 118 119 120]]]]
# ust_data.shape:
#  (2, 3, 4, 5)


print('======================================')
ust_0 = tf.unstack(ust_data, axis=0)
ust_0 = sess.run(ust_0)
print('ust_0: ', ust_0)
print('ust_0.shape: ', len(ust_0), ' * ',  ust_0[0].shape)      #  2 * (3, 4, 5) ==> (2, 3, 4, 5)
# ust_0:  [array([[[ 1,  2,  3,  4,  5],
#         [ 6,  7,  8,  9, 10],
#         [11, 12, 13, 14, 15],
#         [16, 17, 18, 19, 20]],
#
#        [[21, 22, 23, 24, 25],
#         [26, 27, 28, 29, 30],
#         [31, 32, 33, 34, 35],
#         [36, 37, 38, 39, 40]],
#
#        [[41, 42, 43, 44, 45],
#         [46, 47, 48, 49, 50],
#         [51, 52, 53, 54, 55],
#         [56, 57, 58, 59, 60]]]), array([[[ 61,  62,  63,  64,  65],
#         [ 66,  67,  68,  69,  70],
#         [ 71,  72,  73,  74,  75],
#         [ 76,  77,  78,  79,  80]],
#
#        [[ 81,  82,  83,  84,  85],
#         [ 86,  87,  88,  89,  90],
#         [ 91,  92,  93,  94,  95],
#         [ 96,  97,  98,  99, 100]],
#
#        [[101, 102, 103, 104, 105],
#         [106, 107, 108, 109, 110],
#         [111, 112, 113, 114, 115],
#         [116, 117, 118, 119, 120]]])]
# ust_0.shape:  2  *  (3, 4, 5)

print('======================================')
ust_1 = tf.unstack(ust_data, axis=1)
ust_1 = sess.run(ust_1)
print('ust_1: ', ust_1)
print('ust_1.length: ', len(ust_1))
print('ust_1.shape: ', len(ust_1), ' * ', ust_1[0].shape)       # 3 * (2, 4, 5) ==> (3, 2, 4, 5)
# ust_1:  [array([[[ 1,  2,  3,  4,  5],
#         [ 6,  7,  8,  9, 10],
#         [11, 12, 13, 14, 15],
#         [16, 17, 18, 19, 20]],
#
#        [[61, 62, 63, 64, 65],
#         [66, 67, 68, 69, 70],
#         [71, 72, 73, 74, 75],
#         [76, 77, 78, 79, 80]]]), array([[[ 21,  22,  23,  24,  25],
#         [ 26,  27,  28,  29,  30],
#         [ 31,  32,  33,  34,  35],
#         [ 36,  37,  38,  39,  40]],
#
#        [[ 81,  82,  83,  84,  85],
#         [ 86,  87,  88,  89,  90],
#         [ 91,  92,  93,  94,  95],
#         [ 96,  97,  98,  99, 100]]]), array([[[ 41,  42,  43,  44,  45],
#         [ 46,  47,  48,  49,  50],
#         [ 51,  52,  53,  54,  55],
#         [ 56,  57,  58,  59,  60]],
#
#        [[101, 102, 103, 104, 105],
#         [106, 107, 108, 109, 110],
#         [111, 112, 113, 114, 115],
#         [116, 117, 118, 119, 120]]])]
# ust_1.length:  3
# ust_1.shape:  3  *  (2, 4, 5)

print('======================================')
ust_2 = tf.unstack(ust_data, axis=2)
ust_2 = sess.run(ust_2)
print('ust_2: ', ust_2)
print('ust_2.length: ', len(ust_2))
print('ust_2.shape: ', len(ust_2), ' * ', ust_2[0].shape)       # 4 * (2, 3, 5) ==> (4, 2, 3, 5)
# ust_2:  [array([[[  1,   2,   3,   4,   5],
#         [ 21,  22,  23,  24,  25],
#         [ 41,  42,  43,  44,  45]],
#
#        [[ 61,  62,  63,  64,  65],
#         [ 81,  82,  83,  84,  85],
#         [101, 102, 103, 104, 105]]]), array([[[  6,   7,   8,   9,  10],
#         [ 26,  27,  28,  29,  30],
#         [ 46,  47,  48,  49,  50]],
#
#        [[ 66,  67,  68,  69,  70],
#         [ 86,  87,  88,  89,  90],
#         [106, 107, 108, 109, 110]]]), array([[[ 11,  12,  13,  14,  15],
#         [ 31,  32,  33,  34,  35],
#         [ 51,  52,  53,  54,  55]],
#
#        [[ 71,  72,  73,  74,  75],
#         [ 91,  92,  93,  94,  95],
#         [111, 112, 113, 114, 115]]]), array([[[ 16,  17,  18,  19,  20],
#         [ 36,  37,  38,  39,  40],
#         [ 56,  57,  58,  59,  60]],
#
#        [[ 76,  77,  78,  79,  80],
#         [ 96,  97,  98,  99, 100],
#         [116, 117, 118, 119, 120]]])]
# ust_2.length:  4
# ust_2.shape:  4  *  (2, 3, 5)

print('======================================')
ust_3 = tf.unstack(ust_data, axis=3)
ust_3 = sess.run(ust_3)
print('ust_3: ', ust_3)
print('ust_3.length: ', len(ust_3))
print('ust_3.shape: ', len(ust_3), ' * ', ust_3[0].shape)       #  5 * (2, 3, 4) ==> (5, 2, 3, 4)
# ust_3:  [array([[[  1,   6,  11,  16],
#         [ 21,  26,  31,  36],
#         [ 41,  46,  51,  56]],
#
#        [[ 61,  66,  71,  76],
#         [ 81,  86,  91,  96],
#         [101, 106, 111, 116]]]), array([[[  2,   7,  12,  17],
#         [ 22,  27,  32,  37],
#         [ 42,  47,  52,  57]],
#
#        [[ 62,  67,  72,  77],
#         [ 82,  87,  92,  97],
#         [102, 107, 112, 117]]]), array([[[  3,   8,  13,  18],
#         [ 23,  28,  33,  38],
#         [ 43,  48,  53,  58]],
#
#        [[ 63,  68,  73,  78],
#         [ 83,  88,  93,  98],
#         [103, 108, 113, 118]]]), array([[[  4,   9,  14,  19],
#         [ 24,  29,  34,  39],
#         [ 44,  49,  54,  59]],
#
#        [[ 64,  69,  74,  79],
#         [ 84,  89,  94,  99],
#         [104, 109, 114, 119]]]), array([[[  5,  10,  15,  20],
#         [ 25,  30,  35,  40],
#         [ 45,  50,  55,  60]],
#
#        [[ 65,  70,  75,  80],
#         [ 85,  90,  95, 100],
#         [105, 110, 115, 120]]])]
# ust_3.length:  5
# ust_3.shape:  5  *  (2, 3, 4)

# 理解:
    # tf.unstack 其实是将axis维度直接放到最前面
    # 也和 tf.transpose 类似
#      总结:
#     tf.unstack() 中 stacks = (维1,维2, 维3, 维4 )
#     当axis=0时, 就相当于tf.transpose(stacks, [0, 1, 2, 3])
#     当axis=1时, 就相当于tf.transpose(stacks, [1, 0, 2, 3])
#     当axis=2时, 就相当于tf.transpose(stacks, [2, 0, 1, 3])
#     当axis=3时, 就相当于tf.transpose(stacks, [3, 0, 1, 2])

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值