tf.stack()、tf.unstack()函数介绍和示例
1. tf.stack([A, B] axis=0)
释义:矩阵拼接
- A,B,输入张量
- axis,指定拼接维度,默认为0。
示例1:二维数据中,维度 0 处拼接
import tensorflow as tf
A = [[1, 2, 3],
[4, 5, 6]]
B = [[10, 20, 30],
[40, 50, 60]]
X = tf.stack([A, B], axis=0) # 0维拼接
with tf.Session() as sess:
print(sess.run(X))
[[[ 1 2 3]
[ 4 5 6]]
[[10 20 30]
[40 50 60]]]
示例2:二维数据中,维度 1 处拼接
import tensorflow as tf
A = [[1, 2, 3],
[4, 5, 6]]
B = [[10, 20, 30],
[40, 50, 60]]
X = tf.stack([A, B], axis=1) # 1维拼接
with tf.Session() as sess:
print(sess.run(X))
[[[ 1 2 3]
[10 20 30]]
[[ 4 5 6]
[40 50 60]]]
示例3:二维数据中,维度 2 处拼接
import tensorflow as tf
A = [[1, 2, 3],
[4, 5, 6]]
B = [[10, 20, 30],
[40, 50, 60]]
X = tf.stack([A, B], axis=2) # 2维拼接
with tf.Session() as sess:
print(sess.run(X))
[[[ 1 10]
[ 2 20]
[ 3 30]]
[[ 4 40]
[ 5 50]
[ 6 60]]]
2. tf.unstack(A, axis=0)
释义:拆分矩阵
- A,输入张量
- axis,指定维度,默认为 0。二维数据中,若为 0,则按行拆分;若为 1,则按列拆分
示例1:二维数据按行拆分
import tensorflow as tf
A = [[1, 2, 3],
[4, 5, 6]]
X = tf.unstack(A, axis=0)
with tf.Session() as sess:
print(sess.run(X))
[array([1, 2, 3]), array([4, 5, 6])]
示例2:二维数据按列拆分
import tensorflow as tf
A = [[1, 2, 3],
[4, 5, 6]]
X = tf.unstack(A, axis=1)
with tf.Session() as sess:
print(sess.run(X))
[array([1, 4]), array([2, 5]), array([3, 6])]