tensorflow2.1中tensor的合并与分割

1、拼接操作tf.concat

其约束条件:你要创建的axis是可以不同,但其他的维度是要相同的。

用法:

tf.concat
(
  values,
  axis,
  name='concat'
)

axis代表要在哪个维度拼接:axis=0代表在第0个维度拼接,axis=1代表在第1个维度拼接…

示例:

import tensorflow as tf

a = tf.ones([4, 35, 8])
b = tf.ones([2, 35, 8])

c = tf.concat([a, b], axis = 0)
c.shape
Out[6]: TensorShape([6, 35, 8])

a = tf.ones([4, 32, 8])
b = tf.ones([4, 3, 8])

tf.concat([a, b], axis = 1).shape
Out[9]: TensorShape([4, 35, 8])

2、创建新维度的拼接tf.stack

约束条件:要求被拼接的所有维度相同。
用法:

tf.stack
(
    values,
    axis=0,
    name='stack'
)

假设输入是由N个shape为(A,B,C)的tensor组成的一个list

如果沿着axis ==0 进行拼接,那么拼接后的输入的tensor的shape为(N,A,B,C)

如果沿着axis ==1 进行拼接,那么拼接后的输入的tensor的shape为(A,N,B,C)

示例:

a = tf.ones([4, 35, 8])
b = tf.ones([4, 35, 8])

tf.concat([a, b], axis = 0).shape
Out[16]: TensorShape([8, 35, 8])

tf.stack([a, b], axis = 0).shape
Out[17]: TensorShape([2, 4, 35, 8])

tf.stack([a, b], axis = 3).shape
Out[18]: TensorShape([4, 35, 8, 2])

3、stack的反操作降维分割tf.unstack

用法:

tf.unstack
(
    value,
    num=None,
    axis=0,
    name=’unstack’
)

参数说明:

  • value: 一个将要被降维的维度大于0的张量
  • num: 整数。指定的维度axis的长度。如果设置为None(默认值),将自动求值。
  • axis: 整数.以轴axis指定的维度来降维 默认是第一个维度即axis=0。支持负数。取值范围为[-R, R)
  • name: 这个操作的名字

示例:

a = tf.ones([4, 35, 8])
b = tf.ones([4, 35, 8])

c = tf.stack([a, b])
c.shape
Out[22]: TensorShape([2, 4, 35, 8])

aa, bb = tf.unstack(c, axis = 0)
aa.shape, bb.shape
Out[24]: (TensorShape([4, 35, 8]), TensorShape([4, 35, 8]))

res = tf.unstack(c, axis = 3)

res[0].shape, res[7].shape
Out[26]: (TensorShape([2, 4, 35]), TensorShape([2, 4, 35]))

4、灵活性更强的分割tf.split()

用法:

tf.split(
    value,
    num_or_size_splits,
    axis=0,
    num=None,
    name='split'
)

如果num_or_size_splits是一个整数,就等分axis指定的维度,如果num_or_size_splits是一个整数列表,则按列表分割axia指定的维度:
示例:

c = tf.ones([2, 4, 35, 8])
res = tf.unstack(c, axis = 3)
len(res)
Out[31]: 8

res = tf.split(c, axis = 3, num_or_size_splits = 2)
len(res)
Out[33]: 2
res[0].shape, res[1].shape
Out[35]: (TensorShape([2, 4, 35, 4]), TensorShape([2, 4, 35, 4]))

res = tf.split(c, axis = 3, num_or_size_splits = [2, 2, 4])

res[0].shape, res[2].shape
Out[37]: (TensorShape([2, 4, 35, 2]), TensorShape([2, 4, 35, 4]))
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值