一:原型
concat(values, axis, name=“concat”)。简单理解即将传入的values(若干shape完全一样的N维张量)在指定的维度axis(0<= axis <= N-1)上进行拼接,并返回拼接后的张量。
二:代码分析
1:一维张量
a = tf.constant([1,2])
b = tf.constant([3, 4])
c = tf.concat(values=[a, b], axis=0)
with tf.Session() as sess:
print(sess.run(c))
如上面代码所示,定义了两个一维张量a和b,axis的可取值此时只能是0,c为拼接后的结果,运行可知,c=[1 2 3 4]。
a = tf.constant([1])
b = tf.constant([2])
c = tf.constant([3])
d = tf.constant([4])
cat = tf.concat(values=[a, b, c, d], axis=0)
with tf.Session() as sess:
print(sess.run(cat))
同理,以上代码运行后结果为[1 2 3 4]
2:二维张量
对于二维张量而言,axis的可选值包括0和1。
a = tf.constant([
[1, 2, 3],
[4, 5, 6]
])
b = tf.constant([
[7, 8, 9],
[3, 5, 8]
])
cat = tf.concat(values=[a, b], axis=0)
with tf.Session() as sess:
print(sess.run(cat))
如上代码所示,定义了两个shape为(2, 3)的张量a和b,并在第0维上进行concat操作,运行程序知cat的shape为(4, 3),cat的值为:
[[1 2 3]
[4 5 6]
[7 8 9]
[3 5 8]]
可知,axis=0时即在第0维,也就是在二维张量的行上进行了concat操作。将axis改为1后。运行程序知cat的shape为(2, 6),值为:
[[1 2 3 7 8 9]
[4 5 6 3 5 8]]
可知axis=1时相当于对张量a的列进行了扩展。
3:三维张量
对于三维张量而言,axis的可选值包括0、1、2。
a = tf.constant([
[
[1, 1, 1, 1],
[1, 1, 1, 1]
],
[
[2, 2, 2, 2],
[2, 2, 2, 2]
],
[
[3, 3, 3, 3],
[3, 3, 3, 3]
],
])
b = tf.constant([
[
[4, 4, 4, 4],
[4, 4, 4, 4]
],
[
[5, 5, 5, 5],
[5, 5, 5, 5]
],
[
[6, 6, 6, 6],
[6, 6, 6, 6]
],
])
cat = tf.concat(values=[a, b], axis=0)
with tf.Session() as sess:
print(sess.run(cat))
程序中定义了两个shape为(3, 2, 4)的三维张量,并在第0维上进行concat操作,运行程序后,cat的shape为(6, 2, 4),值为:
[[[1 1 1 1]
[1 1 1 1]]
[[2 2 2 2]
[2 2 2 2]]
[[3 3 3 3]
[3 3 3 3]]
[[4 4 4 4]
[4 4 4 4]]
[[5 5 5 5]
[5 5 5 5]]
[[6 6 6 6]
[6 6 6 6]]]
对于三维张量A而言,第0维即表示A中含有多少个二维张量B,由以上结果可知,在0维上拼接,相当于把b中所有的二维张量直接添加到a原有的二维张量后面。
将axis改为1,得到的cat的shape为(3, 4, 4),值为
[[[1 1 1 1]
[1 1 1 1]
[4 4 4 4]
[4 4 4 4]]
[[2 2 2 2]
[2 2 2 2]
[5 5 5 5]
[5 5 5 5]]
[[3 3 3 3]
[3 3 3 3]
[6 6 6 6]
[6 6 6 6]]]
原有的张量a中有三个二维数组,每个二维数组的shape为(2, 4),在第1维上进行拼接后,cat中每个二维数组的shape为(4, 4),即将b中第i个二维数组拼接到a中第i个二维数组后面,并存于cat的第i个位置。
将axis改为2,cat的shape为(3, 2, 8),值为
[[[1 1 1 1 4 4 4 4]
[1 1 1 1 4 4 4 4]]
[[2 2 2 2 5 5 5 5]
[2 2 2 2 5 5 5 5]]
[[3 3 3 3 6 6 6 6]
[3 3 3 3 6 6 6 6]]]
同理,读者可自行理解。