keras中concatenate源代码如下:
def concatenate(tensors, axis=-1):
"""Concatenates a list of tensors alongside the specified axis.
# Arguments
tensors: list of tensors to concatenate.
axis: concatenation axis.
# Returns
A tensor.
"""
if axis < 0:
rank = ndim(tensors[0])
if rank:
axis %= rank
else:
axis = 0
if py_all([is_sparse(x) for x in tensors]):
return tf.sparse_concat(axis, tensors)
else:
return tf.concat([to_dense(x) for x in tensors], axis)
可以看出keras的concatenate()函数是披了外壳的tf.concat()。不过用法没有tf.concat()那么复杂。对tf.concat()解释可以看我的另一篇博文《tf.concat()详解》,如果只想了解concatenate的用法,可以不用移步。
axis=n表示从第n个维度进行拼接,对于一个三维矩阵,axis的取值可以为[-3, -2, -1, 0, 1, 2]。虽然keras用模除允许axis的取值可以在这个范围之外,但不建议那么用。
可以通过如下小段代码来理解:
import numpy as np
import cv2
import keras.backend as K
import tensorflow as tf
t1 = K.variable(np.array([[[1, 2], [2, 3]], [[4, 4], [5, 3]]]))
t2 = K.variable(np.array([[[7, 4], [8, 4]], [[2, 10], [15, 11]]]))
d0 = K.concatenate([t1 , t2] , axis=-2)
d1 = K.concatenate([t1 , t2] , axis=1)
d2 = K.concatenate([t1 , t2] , axis=-1)
d3 = K.concatenate([t1 , t2] , axis=2)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(sess.run(d0))
print(sess.run(d1))
print(sess.run(d2))
print(sess.run(d3))
axis=-2,意思是从倒数第2个维度进行拼接,对于三维矩阵而言,这就等同于axis=1。
axis=-1,意思是从倒数第1个维度进行拼接,对于三维矩阵而言,这就等同于axis=2。
输出如下:
d0:
[[[ 1. 2.]
[ 2. 3.]
[ 7. 4.]
[ 8. 4.]]
[[ 4. 4.]
[ 5. 3.]
[ 2. 10.]
[ 15. 11.]]]
d1:
[[[ 1. 2.]
[ 2. 3.]
[ 7. 4.]
[ 8. 4.]]
[[ 4. 4.]
[ 5. 3.]
[ 2. 10.]
[ 15. 11.]]]
d2:
[[[ 1. 2. 7. 4.]
[ 2. 3. 8. 4.]]
[[ 4. 4. 2. 10.]
[ 5. 3. 15. 11.]]]
d3:
[[[ 1. 2. 7. 4.]
[ 2. 3. 8. 4.]]
[[ 4. 4. 2. 10.]
[ 5. 3. 15. 11.]]]