tf.concat()的理解和使用

一:原型

concat(values, axis, name=“concat”)。简单理解即将传入的values(若干shape完全一样的N维张量)在指定的维度axis(0<= axis <= N-1)上进行拼接,并返回拼接后的张量。

二:代码分析

1:一维张量

a = tf.constant([1,2])
b = tf.constant([3, 4])
c = tf.concat(values=[a, b], axis=0)
with tf.Session() as sess:
    print(sess.run(c))

如上面代码所示,定义了两个一维张量a和b,axis的可取值此时只能是0,c为拼接后的结果,运行可知,c=[1 2 3 4]。

a = tf.constant([1])
b = tf.constant([2])
c = tf.constant([3])
d = tf.constant([4])
cat = tf.concat(values=[a, b, c, d], axis=0)
with tf.Session() as sess:
    print(sess.run(cat))

同理,以上代码运行后结果为[1 2 3 4]

2:二维张量

对于二维张量而言,axis的可选值包括0和1。

a = tf.constant([
    [1, 2, 3],
    [4, 5, 6]
])
b = tf.constant([
    [7, 8, 9],
    [3, 5, 8]
])

cat = tf.concat(values=[a, b], axis=0)
with tf.Session() as sess:
    print(sess.run(cat))

如上代码所示,定义了两个shape为(2, 3)的张量a和b,并在第0维上进行concat操作,运行程序知cat的shape为(4, 3),cat的值为:

[[1 2 3]
 [4 5 6]
 [7 8 9]
 [3 5 8]]

可知,axis=0时即在第0维,也就是在二维张量的行上进行了concat操作。将axis改为1后。运行程序知cat的shape为(2, 6),值为:

[[1 2 3 7 8 9]
 [4 5 6 3 5 8]]

可知axis=1时相当于对张量a的列进行了扩展。

3:三维张量

对于三维张量而言,axis的可选值包括0、1、2。

a = tf.constant([
    [
        [1, 1, 1, 1],
        [1, 1, 1, 1]
    ],
    [
        [2, 2, 2, 2],
        [2, 2, 2, 2]
    ],
    [
        [3, 3, 3, 3],
        [3, 3, 3, 3]
    ],
])
b = tf.constant([
    [
        [4, 4, 4, 4],
        [4, 4, 4, 4]
    ],
    [
        [5, 5, 5, 5],
        [5, 5, 5, 5]
    ],
    [
        [6, 6, 6, 6],
        [6, 6, 6, 6]
    ],
])

cat = tf.concat(values=[a, b], axis=0)
with tf.Session() as sess:
    print(sess.run(cat))

程序中定义了两个shape为(3, 2, 4)的三维张量,并在第0维上进行concat操作,运行程序后,cat的shape为(6, 2, 4),值为:

[[[1 1 1 1]
  [1 1 1 1]]

 [[2 2 2 2]
  [2 2 2 2]]

 [[3 3 3 3]
  [3 3 3 3]]

 [[4 4 4 4]
  [4 4 4 4]]

 [[5 5 5 5]
  [5 5 5 5]]

 [[6 6 6 6]
  [6 6 6 6]]]

对于三维张量A而言,第0维即表示A中含有多少个二维张量B,由以上结果可知,在0维上拼接,相当于把b中所有的二维张量直接添加到a原有的二维张量后面。
将axis改为1,得到的cat的shape为(3, 4, 4),值为

[[[1 1 1 1]
  [1 1 1 1]
  [4 4 4 4]
  [4 4 4 4]]

 [[2 2 2 2]
  [2 2 2 2]
  [5 5 5 5]
  [5 5 5 5]]

 [[3 3 3 3]
  [3 3 3 3]
  [6 6 6 6]
  [6 6 6 6]]]

原有的张量a中有三个二维数组,每个二维数组的shape为(2, 4),在第1维上进行拼接后,cat中每个二维数组的shape为(4, 4),即将b中第i个二维数组拼接到a中第i个二维数组后面,并存于cat的第i个位置。
将axis改为2,cat的shape为(3, 2, 8),值为

[[[1 1 1 1 4 4 4 4]
  [1 1 1 1 4 4 4 4]]

 [[2 2 2 2 5 5 5 5]
  [2 2 2 2 5 5 5 5]]

 [[3 3 3 3 6 6 6 6]
  [3 3 3 3 6 6 6 6]]]

同理,读者可自行理解。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值