tf.stack() 详解

tensorflow用于矩阵拼接的方法:tf.stack()

个人参考感觉还不错的一个理解(tf.stack() 和 tf.concat()的区别):https://blog.csdn.net/Gai_Nothing/article/details/88416782

 

def stack(values, axis=0, name="stack"):
    """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.
  Packs the list of tensors in `values` into a tensor with rank one higher than
  each tensor in `values`, by packing them along the `axis` dimension.
  Given a list of length `N` of tensors of shape `(A, B, C)`;
  if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
  if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
  Etc."""
 
    '''Args:
    values: A list of `Tensor` objects with the same shape and type.
    axis: An `int`. The axis to stack along. Defaults to the first dimension.
      Negative values wrap around, so the valid range is `[-(R+1), R+1)`.
    name: A name for this operation (optional).'''
 个人理解 ~ 测试:

import tensorflow as tf
import numpy as np
 
sess = tf.Session()
sess.run(tf.global_variables_initializer())
 
# stack and unstack
stack_data1, stack_data2 = np.arange(1, 31).reshape([2, 3, 5])
print('stack_data1: \n', stack_data1)
print('stack_data1.shape: \n', stack_data1.shape)
print('stack_data2: \n', stack_data2)
print('stack_data2.shape: \n', stack_data2.shape)
# stack_data1:
#  [[ 1  2  3  4  5]
#  [ 6  7  8  9 10]
#  [11 12 13 14 15]]
# stack_data1.shape:
#  (3, 5)
# stack_data2:
#  [[16 17 18 19 20]
#  [21 22 23 24 25]
#  [26 27 28 29 30]]
# stack_data2.shape:
#  (3, 5)
 
# 理解:
#     举例:当前两个个张量的维度均为:(维1,维2, 维3, 维4), 此时axis的取值范围为:[-5, 5)
#     所以输入 stacks = [stack_data1, stack_data2], st = tf.stack(stacks, axis=?)
#     此时:
#           stacks的维度为:(2,维1,维2, 维3, 维4 )   维度为5,所以输出维度也为5, axis取值就在[-5, 5)
#           当axis=0时, st维度为:(2, 维1, 维2, 维3, 维4)
#           当axis=1时, st维度为:(维1, 2,维2, 维3, 维4)
#           当axis=2时, st维度为:(维1, 维2, 2,维3, 维4)
#           当axis=3时, st维度为:(维1, 维2, 维3,2,维4)
#           当axis=4时, st维度为:(维1, 维2, 维3,维4,2)
 
#           当axis=-5时, st维度为:(2, 维1, 维2, 维3, 维4)
#           当axis=-4时, st维度为:(维1, 2,维2, 维3, 维4)
#           当axis=-3时, st维度为:(维1, 维2, 2,维3, 维4)
#           当axis=-2时, st维度为:(维1, 维2, 维3,2,维4)
#           当axis=-1时, st维度为:(维1, 维2, 维3,维4,2)
 
print('======================================')
st_0 = tf.stack([stack_data1, stack_data2], axis=0)     # 2 * (3, 5) ==> (2, 3, 5)
st_0 = sess.run(st_0)
print('st_0: \n', st_0)
print('st_0.shape: \n', st_0.shape)
# st_0:
#  [[[ 1  2  3  4  5]
#   [ 6  7  8  9 10]
#   [11 12 13 14 15]]
#
#  [[16 17 18 19 20]
#   [21 22 23 24 25]
#   [26 27 28 29 30]]]
# st_0.shape:
#  (2, 3, 5)
 
print('======================================')
st_1 = tf.stack([stack_data1, stack_data2], axis=1)     # 2 * (3, 5) ==> (3, 2, 5)
st_1 = sess.run(st_1)
print('st_1: \n', st_1)
print('st_1.shape: \n', st_1.shape)
# st_1:
#  [[[ 1  2  3  4  5]
#   [16 17 18 19 20]]
#
#  [[ 6  7  8  9 10]
#   [21 22 23 24 25]]
#
#  [[11 12 13 14 15]
#   [26 27 28 29 30]]]
# st_1.shape:
#  (3, 2, 5)
 
print('======================================')
st_2 = tf.stack([stack_data1, stack_data2], axis=2)     # 2 * (3, 5) ==> (3, 5, 2)
st_2 = sess.run(st_2)
print('st_2: \n', st_2)
print('st_2.shape: \n', st_2.shape)
# st_2:
#  [[[ 1 16]
#   [ 2 17]
#   [ 3 18]
#   [ 4 19]
#   [ 5 20]]
#
#  [[ 6 21]
#   [ 7 22]
#   [ 8 23]
#   [ 9 24]
#   [10 25]]
#
#  [[11 26]
#   [12 27]
#   [13 28]
#   [14 29]
#   [15 30]]]
# st_2.shape:
#  (3, 5, 2)
 
print('======================================')
st_1_ = tf.stack([stack_data1, stack_data2], axis=-1)     # 2 * (3, 5) ==>  (3, 5, 2)   等同于st_2
st_1_ = sess.run(st_1_)
print('st_1_: \n', st_1_)
print('st_1_.shape: \n', st_1_.shape)
# st_1:
#  [[[ 1 16]
#   [ 2 17]
#   [ 3 18]
#   [ 4 19]
#   [ 5 20]]
#
#  [[ 6 21]
#   [ 7 22]
#   [ 8 23]
#   [ 9 24]
#   [10 25]]
#
#  [[11 26]
#   [12 27]
#   [13 28]
#   [14 29]
#   [15 30]]]
# st_1.shape:
#  (3, 5, 2)
 
print('=================比较st_1, 和 transpose=====================')
print('st_1: \n', st_1)
transpose_test = sess.run(tf.transpose(st_0, [1, 0, 2]))
print('transpose_test: \n', transpose_test)
print('transpose_test == st_1: \n', transpose_test == st_1)
 
print('=================比较st_2, 和 transpose=====================')
print('st_2: \n', st_2)
transpose_test = sess.run(tf.transpose(st_0, [1, 2, 0]))
print('transpose_test: \n', transpose_test)
print('transpose_test == st_2: \n', transpose_test == st_2)
# 总结:
#     tf.stack() 中 stacks = (2,维1,维2, 维3, 维4 )
#     当axis=0时, 就相当于tf.transpose(stacks, [0, 1, 2, 3, 4])
#     当axis=1时, 就相当于tf.transpose(stacks, [1, 0, 2, 3, 4])
#     当axis=2时, 就相当于tf.transpose(stacks, [1, 2, 0, 3, 4])
#     当axis=3时, 就相当于tf.transpose(stacks, [1, 2, 3, 0, 4])
#     当axis=0时, 就相当于tf.transpose(stacks, [1, 2, 3, 4, 0])
 
 
# 4 维测试:
stack_data1, stack_data2 = np.arange(1, 121).reshape([2, 3, 4, 5]) # (2, 3, 4, 5)
st_ = tf.stack([stack_data1, stack_data2], axis=3)
st_0 = tf.stack([stack_data1, stack_data2], axis=0)
st_ = sess.run(st_)
st_0 = sess.run(st_0)
 
tr_ = tf.transpose(st_0, [1, 2, 3, 0])
tr_ = sess.run(tr_)
 
print('st_.shape: ', st_.shape)
print('st_: ', st_)
 
print('tr_.shape: ', tr_.shape)
print('tr_: ', tr_)
 
print(st_ == tr_)

————————————————
版权声明:本文为CSDN博主「feifeiyechuan」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/feifeiyechuan/article/details/89388103

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
TensorFlow 中的 `tf.stack` 和 `tf.concat` 都可以用于将多个张量拼接成一个张量,但它们的实现方式略有不同,具体如下: - `tf.concat`: 沿着一个指定的维度将多个张量拼接起来。例如,将两个形状为 `(3, 4)` 的张量沿着第一个维度拼接起来,得到一个形状为 `(6, 4)` 的张量。`tf.concat` 的实现方式是将多个张量在指定维度上直接拼接,因此要求各个输入张量在指定维度上大小相同。 - `tf.stack`: 沿着一个新的维度将多个张量堆叠起来。例如,将两个形状为 `(3, 4)` 的张量在第三个维度上堆叠起来,得到一个形状为 `(3, 4, 2)` 的张量。`tf.stack` 的实现方式是创建一个新的维度,并在这个维度上将各个输入张量堆叠起来,因此各个输入张量的大小可以不同,但在其它维度上的大小必须相同。 下面是具体的使用示例: ```python import tensorflow as tf # 定义两个张量 a = tf.constant([1, 2, 3]) b = tf.constant([4, 5, 6]) # 使用 tf.concat 将两个张量拼接成一个张量 c = tf.concat([a, b], axis=0) print(c) # 输出 [1 2 3 4 5 6] # 使用 tf.stack 将两个张量堆叠成一个张量 d = tf.stack([a, b], axis=1) print(d) # 输出 [[1 4] [2 5] [3 6]] ``` 在上面的例子中,我们首先定义了两个形状相同的张量 `a` 和 `b`。然后我们使用 `tf.concat` 将它们沿着第一个维度拼接起来,得到一个形状为 `(6,)` 的张量 `c`;接着使用 `tf.stack` 将它们在第二个维度上堆叠起来,得到一个形状为 `(3, 2)` 的张量 `d`。可以看到,`tf.concat` 和 `tf.stack` 的输出结果是不同的,这是因为它们的实现方式不同,使用时需要根据具体的需求选择合适的方法。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值