现列出代码,后面进行讲解
import numpy as np
import tensorflow as tf
a = tf.constant([[1,2],[3,4]])
b = tf.constant([[6,7],[8,9]])
c = tf.stack([a,b], axis=0)
d = tf.stack([a,b], axis=1)
e = tf.unstack(a, axis=0)
f = tf.unstack(a, axis=1)
with tf.Session() as sess:
print(sess.run(c))
print(sess.run(d))
print(sess.run(e))
print(sess.run(f))
结果如下:
c: [[[1 2]
[3 4]]
[[6 7]
[8 9]]]
d: [[[1 2]
[6 7]]
[[3 4]
[8 9]]]
e: [array([1, 2], dtype=int32), array([3, 4], dtype=int32)]
f: [array([1, 3], dtype=int32), array([2, 4], dtype=int32)]
1、c和d涉及到的是tf.stack的合并,先将要合并的对象用[]合并起来,如[a, b],然后axis表示对哪一个维度进行合并。
2、e和f设计到的是tf,unstack的分解,与tf.stack的用法正好相反,axis表示对哪一维度进行分解。