【Pytorch】Tensor的重塑操作

Tensor的操作

Tensor的主要运算操作通常分为四大类:

  1. Reshaping operations(重塑操作)
  2. Element-wise operations(元素操作)
  3. Reduction operations(缩减操作)
  4. Access operations(访问操作)

重塑操作

​ 在重塑操作上,以如下张量为实例进行演示:

import torch
t = torch.tensor([
    [1, 1, 1, 1],
    [2, 2, 2, 2],
    [3, 3, 3, 3]
], dtype=torch.float32)

​ (1)在重塑操作上,我们通常比较关心元素的数量,其具体的查询方式如下:

torch.tensor(t.shape).prod()

显示结果:

tensor(12)

或者使用如下函数:

t.numel()

显示结果:

12

​ (2)在不改变张量的秩的情况下,进行重塑的方式如下:

t.reshape(1, 12)
t.reshape(2, 6)
t.reshape(3, 4)
t.reshape(4, 3)
t.reshape(6, 2)
t.reshape(12, 1)

可以看到形状中各分量值的乘积只有等于元素的个数才能运行,因此元素个数对重塑操作来说十分重要!

(3)在改变张量的秩的情况下,进行重塑的方式如下:

# 这里以秩为三的张量为例(仅举一个例子)
t.reshape(2, 2, 3)

显示结果:

tensor([[[1., 1., 1.],
         [1., 2., 2.]],

        [[2., 2., 3.],
         [3., 3., 3.]]])

(4)通过压缩(squeezing)和解压(unsqueezing)来改变张量的形状

# 正常情况下
print(t.reshape(1, 12))
print(t.reshape(1, 12).shape)

显示结果:

tensor([[1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.]])
torch.Size([1, 12])
# squeezing
print(t.reshape(1, 12).squeeze())
print(t.reshape(1, 12).squeeze().shape)

显示结果:

tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.])
torch.Size([12])

结果分析:

压缩一个张量可以移除所有长度为1的轴!

# unsqueezing
print(t.reshape(1, 12).squeeze().unsqueeze(dim=0))
print(t.reshape(1, 12).squeeze().unsqueeze(dim=0).shape)

显示结果:

tensor([[1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.]])
torch.Size([1, 12])

结果分析:

解压一个张量会增加一个长度为1的维度!

Flatten

压缩与解压的常见应用——flatten函数

​ 一个flatten的操作是当我们从一个卷积层过渡到一个全连接层时必须在神经网络中发生的操作,flatten函数的代码实现可以如下:

def flatten(t):
    t = t.reshape(1, -1)
    t = t.squeeze()
    return t

函数的参数t可以是任何形状的张量,而在重塑函数reshape中的第二个参数通常给定为-1,-1可以告诉reshape函数根据一个张量中包含的其他值和元素的个数来求出这个值应该是多少!

​ 上述代码考虑的情况较为单一,而通常情况下张量是成批的,这时需要选择性地使用flatten操作。比如在使用CNN时,需要处理的一个张量一般包含多个图像的批,这时需要指定特定的轴来进行flatten操作。

​ 这里需要强调的是,一个神经网络的张量输入通常有四个轴,一个是表示批量大小,一个是表示颜色通道,一个是表示图像高度,一个是表示图像宽度。

​ 现假设有三个张量如下:

t1 = torch.tensor([
    [1, 1, 1, 1],
    [1, 1, 1, 1],
    [1, 1, 1, 1],
    [1, 1, 1, 1]
])

t2 = torch.tensor([
    [2, 2, 2, 2],
    [2, 2, 2, 2],
    [2, 2, 2, 2],
    [2, 2, 2, 2]
])

t3 = torch.tensor([
    [3, 3, 3, 3],
    [3, 3, 3, 3],
    [3, 3, 3, 3],
    [3, 3, 3, 3]
])

这三个张量表示三张4乘以4的图像,用这些张量来创建一个批次传入CNN,即合并成一个更大的张量:

t = torch.stack((t1, t2, t3))
print(t.shape)

显示结果:

torch.Size([3, 4, 4])

接着,为每个张量添加一个单色通道:

t = t.reshape(3, 1, 4, 4)
print(t)

显示结果:

tensor([[[[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]]],


        [[[2, 2, 2, 2],
          [2, 2, 2, 2],
          [2, 2, 2, 2],
          [2, 2, 2, 2]]],


        [[[3, 3, 3, 3],
          [3, 3, 3, 3],
          [3, 3, 3, 3],
          [3, 3, 3, 3]]]])

​ 然而在将该张量输入CNN时,我们不希望整批进行flatten,我们只需要把这批张量中的每一张图片进行flatten。

​ 这里先演示对整批进行flatten的四种代码实现(也可以使用上述自定义的flatten函数):

# one
t.reshape(1,-1)[0]
# two
t.reshape(-1)
# three
t.view(t.numel())
# four
t.flatten()

显示结果:

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

​ 接着演示在保持批轴不变的前提下,对每一张图片进行单独flatten操作:

print(t.flatten(start_dim=1).shape)
print(t.flatten(start_dim=1))

显示结果:

torch.Size([3, 16])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]])

在调用flatten函数时,指定start_dim参数可以告诉flatten函数当它开始flatten操作时应该从哪个轴开始。
根据上述操作,即可实现对批中每一个图像进行单独flatten的操作了!

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值