tensorflow2.1的维度变换

1、Reshape

函数的作用是将tensor变换为参数shape的形式,其中shape为一个列表形式,特殊的一点是列表中可以存在-1

-1代表的含义是不用我们自己指定这一维的大小,函数会自动计算,但列表只能存在一个-1。(如果存在多个-1,就是一个存在多解的方程)

a = tf.random.normal([4, 28,28, 3])

a.shape, a.ndim
Out[74]: (TensorShape([4, 28, 28, 3]), 4)

tf.reshape(a, [4, 784, 3]).shape
Out[76]: TensorShape([4, 784, 3])

tf.reshape(a, [4, -1, 3]).shape
Out[77]: TensorShape([4, 784, 3])

tf.reshape(a, [4, 784*3]).shape
Out[78]: TensorShape([4, 2352])

tf.reshape(a, [4, -1]).shape
Out[79]: TensorShape([4, 2352])

# 恢复
a = tf.random.normal([4, 28,28, 3])

tf.reshape(tf.reshape(a, [4, -1]), [4, 28, 28, 3]).shape
Out[92]: TensorShape([4, 28, 28, 3])

tf.reshape(tf.reshape(a, [4, -1]), [4, 14, 56, 3]).shape
Out[93]: TensorShape([4, 14, 56, 3])

tf.reshape(tf.reshape(a, [4, -1]), [4, 1, 784, 3]).shape
Out[94]: TensorShape([4, 1, 784, 3])

2、转置 tf.transpose

将a进行转置,并且根据perm参数重新排列输出维度.

tf.transpose
(
    a,
    perm=None,
    name='transpose',
    conjugate=False
)

a - 表示的是需要变换的张量

perm - a的新的维度序列

name - 操作的名字,可选的

conjugate - 可选的,设置成True,那么就等于tf.conj(tf.transpose(input)),用的太少了

注:perm-控制转置的操作,perm = [0, 1, 3, 2]表示,把将要转置的第0和第1维度不变,将第2和第3维度进行转置。

a = tf.random.normal((4,3,2,1))

a.shape
Out[96]: TensorShape([4, 3, 2, 1])

tf.transpose(a).shape
Out[98]: TensorShape([1, 2, 3, 4])

tf.transpose(a, perm = [0, 1, 3, 2]).shape
Out[100]: TensorShape([4, 3, 1, 2])

注:pytorch的数据格式一般是[b, 3, h, w], 而tensorFlow的数据格式一般是[b, h, w, 3],所以数据传递时,需要做格式的转换。

a = tf.random.normal([4, 28, 28, 3])  # 一个tensor格式

tf.transpose(a, [0, 2, 1, 3]).shape
Out[103]: TensorShape([4, 28, 28, 3])

tf.transpose(a, [0, 3, 2, 1]).shape
Out[104]: TensorShape([4, 3, 28, 28])

tf.transpose(a, [0, 3, 1, 2]).shape  # 转换成pytorch的数据格式
Out[105]: TensorShape([4, 3, 28, 28])

3、增减一个维度

  • 增加一个维度tf.expand_dims

注:当axis为正数时,在tensor正向对应维度的左边增加一个维度

当axis为负数时,在tensor反向对应维度的右边增加一个维度

a = tf.random.normal([4, 35, 8])

tf.expand_dims(a, axis = 0).shape
Out[107]: TensorShape([1, 4, 35, 8])

tf.expand_dims(a, axis = 3).shape
Out[108]: TensorShape([4, 35, 8, 1])

tf.expand_dims(a, axis = -1).shape
Out[110]: TensorShape([4, 35, 8, 1])

tf.expand_dims(a, axis = -4).shape
Out[111]: TensorShape([1, 4, 35, 8])
  • 减少维度 tf.squeeze

注,只能去掉shape = 1的维度,如[4, 35, 8, 1], 只能去掉最后为1的维度, axis可以指定值为1的维度

tf.squeeze(tf.zeros([1, 2, 1, 1, 3])).shape
Out[112]: TensorShape([2, 3])

tf.squeeze(a, axis = 0).shape
Out[114]: TensorShape([2, 1, 3])

tf.squeeze(a, axis = 2).shape
Out[115]: TensorShape([1, 2, 3])

tf.squeeze(a, axis = 1).shape  # 报错
tensorflow.python.framework.errors_impl.InvalidArgumentError: Can not squeeze dim[1], expected a dimension of 1, got 2 [Op:Squeeze]

tf.squeeze(a, axis = -2).shape
Out[117]: TensorShape([1, 2, 3])

tf.squeeze(a, axis = -4).shape
Out[118]: TensorShape([2, 1, 3])
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值