环境：TensorFlow 2.x
函数签名：tf.keras.layers.concatenate(inputs, axis=-1, **kwargs)
axis=n 表示沿第 n 个维度（从 0 开始计数）进行拼接；负数表示从最后一个维度倒数，axis=-1 即最后一维。对于一个三维张量，axis 的取值可以为 [-3, -2, -1, 0, 1, 2]，其中 -3、-2、-1 分别等价于 0、1、2。除拼接所在的维度外，各输入张量在其余维度上的大小必须相同。
代码
import numpy as np
import tensorflow as tf

# Two fixed 3-D integer inputs of shape (2, 2, 2). They are plain constants,
# not trainable state, so tf.constant is used rather than tf.Variable.
# dtype is pinned to int32 because NumPy's default integer type is
# platform-dependent (e.g. int32 on Windows under NumPy 1.x, int64 on
# Linux/macOS); without it the printed dtype would vary by machine and not
# match the results listed below.
t1 = tf.constant(np.array([[[1, 2], [2, 3]], [[4, 4], [5, 3]]], dtype=np.int32))
t2 = tf.constant(np.array([[[7, 4], [8, 4]], [[2, 10], [15, 11]]], dtype=np.int32))

# Concatenate along each axis; only the chosen axis grows:
# axis=0 -> (4, 2, 2), axis=1 -> (2, 4, 2), axis=2 -> (2, 2, 4).
d0 = tf.keras.layers.concatenate([t1, t2], axis=0)
d1 = tf.keras.layers.concatenate([t1, t2], axis=1)
d2 = tf.keras.layers.concatenate([t1, t2], axis=2)
# Negative axes count from the end, so for 3-D inputs axis=-1 is the same
# as axis=2 (d3 prints identically to d2).
d3 = tf.keras.layers.concatenate([t1, t2], axis=-1)

print(d0)
print(d1)
print(d2)
print(d3)
结果:
tf.Tensor(
[[[ 1  2]
  [ 2  3]]

 [[ 4  4]
  [ 5  3]]

 [[ 7  4]
  [ 8  4]]

 [[ 2 10]
  [15 11]]], shape=(4, 2, 2), dtype=int32)
tf.Tensor(
[[[ 1  2]
  [ 2  3]
  [ 7  4]
  [ 8  4]]

 [[ 4  4]
  [ 5  3]
  [ 2 10]
  [15 11]]], shape=(2, 4, 2), dtype=int32)
tf.Tensor(
[[[ 1  2  7  4]
  [ 2  3  8  4]]

 [[ 4  4  2 10]
  [ 5  3 15 11]]], shape=(2, 2, 4), dtype=int32)
tf.Tensor(
[[[ 1  2  7  4]
  [ 2  3  8  4]]

 [[ 4  4  2 10]
  [ 5  3 15 11]]], shape=(2, 2, 4), dtype=int32)
Process finished with exit code 0