Pytorch 基础之维度变化

Pytorch 张量维度变化是比较常用和重要的操作,本文主要介绍几种常用的维度变化方法:

1. view()

方法释义:返回当前张量的视图张量;

Pytorch 允许一个 tensor 成为现有 tensor 的一个视图,视图张量与其基础张量共享同样的底层数据。视图张量能够避免明显的数据拷贝,因而能够让我们快速且内存高效地进行张量重塑、切片和逐元素操作。

所以可以通过 t.view() 方法来获取 tensor t 的视图张量。

示例:

a = torch.rand(4, 4)
print(a.shape)
b = a.view(2, 8)
print(b.shape)
print(a.storage().data_ptr())  # a 张量的内存数据指针地址
print(b.storage().data_ptr())  # b 张量的内存数据指针地址,其实是一样的
b[0][0] = 3.14
print(a[0][0])    # 内存共享,所以改变 b,相应 a 数据也会变化
c = a.view(2, 7)  # 报错,因为视图张量数据共享,故要 size 必须一样,见输出结果报错信息

# 输出结果
torch.Size([4, 4])
torch.Size([2, 8])
140495468325312
140495468325312
tensor(3.1400)
RuntimeError: shape '[2, 7]' is invalid for input of size 16
2. torch.reshape(inputshape) → [Tensor]

方法释义:返回与 input 张量一样数据和大小的,且与给定 shape 一样的张量。如果可能,返回的是input 张量的视图,否则返回的是其拷贝。

示例:

a = torch.arange(4.)
print(a)
b = torch.reshape(a, (2, 2))  # 返回的是与 shape(2,2)一样的 tensor,数据与 a 一样
print(b)

c = torch.tensor([[0, 1], [2, 3]])
d = torch.reshape(c, (-1,))  # 如果 shape 是 (-1,) ,则返回的与 input 一样数量的一维张量
print(d)

# 输出结果
tensor([0., 1., 2., 3.])
tensor([[0., 1.],
        [2., 3.]])
tensor([0, 1, 2, 3])
3. torch.squeeze(inputdim=None) → [Tensor]

方法释义:将 input 张量中所有维度数据为 1 的维度给移除掉。指定了 dim,如果 dim 对应维度的值不为 1 ,则保持不变,为 1 则移除该维度。

示例:

a = torch.ones(2, 1, 2, 1, 3)
print(a.shape)
b = torch.squeeze(a)  # 将所有维度中数值为 1 的移除掉
print(b.shape)
c = torch.squeeze(a, 0)  # 第一个维度数值不为 1,则不变动
print(c.shape)
d = torch.squeeze(a, 3)  # 第四个维度值为 1,则移除第四个维度
print(d.shape)

# 输出结果
torch.Size([2, 1, 2, 1, 3])
torch.Size([2, 2, 3])
torch.Size([2, 1, 2, 1, 3])
torch.Size([2, 1, 2, 3])
4. torch.unsqueeze(inputdim) → [Tensor]

方法释义:在给定的 dim 维度位置插入一个新的维度,维度数值为 1,dim 的范围在 [-dim()-1, dim()+1),包首不包尾

示例:

a = torch.ones(2, 3, 4, 5)   # 如 a 的维度为 4,那么指定的 dim 值在 [-5, 4]
b = torch.unsqueeze(a, 0)
print(b.shape)
c = torch.unsqueeze(a, 4)
print(c.shape)
d = torch.unsqueeze(a, -1)
print(d.shape)
e = torch.unsqueeze(a, -5)
print(e.shape)

#输出结果
torch.Size([1, 2, 3, 4, 5])
torch.Size([2, 3, 4, 5, 1])
torch.Size([2, 3, 4, 5, 1])
torch.Size([1, 2, 3, 4, 5])
5. Tensor.expand( *sizes) → [Tensor]

方法释义:返回张量的新视图,其某个维度 size 扩展到更大的 size,如果当前维度 size 为 -1 ,表示当前维度 size 保持不变。

Tensor也可以扩展到更多的维度,新的会追加在最前面。对于新维度,大小不能设置为 -1;

扩展张量不会分配新内存,而只会在现有张量上创建一个新视图。任何大小为1的维度都可以扩展为任意值,而无需分配新内存。

参数:
*sizes (torch.Size or [int] – 需要扩展的 size 值

示例:

a = torch.ones(1, 32, 1, 1)
b = a.expand(4, 32, 5, 3)    # 维度为 1 的 size 可以扩展成什么任意的 size
print(b.shape)
c = a.expand(-1, 32, -1, -1) # -1 表示对应的维度 size 不变,但如果 32 改成 33 则会报错
print(c.shape)
d = a.expand(-1, 32, -1, 5)  # 最后一个维度 size 扩展成 5
print(d.shape)
e = a.expand(5, 1, 32, 1, 1) # 可以扩展新的维度,但只会放到最前面,不能放到后面(会报错)且不能设置为 -1 
print(e.shape)

# 输出结果
torch.Size([4, 32, 5, 3])
torch.Size([1, 32, 1, 1])
torch.Size([1, 32, 1, 5])
torch.Size([5, 1, 32, 1, 1])
6. Tensor.repeat( *sizes) → [Tensor]

方法释义:根据指定维度复制张量,与 expand 不同的是,该方法会拷贝原张量的数据

参数:
sizes (torch.Size or [int] – 指定维度复制的次数

示例:

a = torch.ones(1, 32, 1, 1)
print(a.storage().data_ptr())
b = a.expand(4, 32, 5, 3)
print(b.storage().data_ptr())   # expand 操作后,张量的内存地址没变

c = a.repeat(4, 1, 2, 3)        # 重复的次数,返回为 torch.Size([4, 32, 2, 3])
print(c.shape)
print(c.storage().data_ptr())   # repeat 操作后,张量的内存地址会改变
d = a.repeat(4, 1, 2, 3, 2)     # 指定的重复次数数组长度超过张量的维度,则会将张量从前面扩展成一样长度的维度再复制,a 会先扩展成 torch.Size([1, 1,32, 2, 3]), 再复制在 torch.Size([4, 1, 64, 3, 2])
print(d.shape)
print(d.storage().data_ptr())

# 输出结果
140220581413824
140220581413824
torch.Size([4, 32, 2, 3])
140220501380608
torch.Size([4, 1, 64, 3, 2])
140220501807104
7. torch.transpose(inputdim0dim1) → [Tensor]

方法释义:返回 input 张量的转置,dim0 与 dim 1 交换位置

参数:

  • input ([Tensor] 输入的张量
  • dim0 ([int] 第一个要转置的维度
  • dim1 ([int] 第二个要转置的维度

示例:

a = torch.rand(2, 3)
print(a)
b = torch.transpose(a, 0, 1) # 第一维度与第二维度转换位置
print(b)
print(torch.t(a))            # 也可以使用 torch.t(a)

# 输出结果
tensor([[0.6242, 0.2934, 0.4182],
        [0.2461, 0.0797, 0.9801]])
tensor([[0.6242, 0.2461],
        [0.2934, 0.0797],
        [0.4182, 0.9801]])
tensor([[0.6242, 0.2461],
        [0.2934, 0.0797],
        [0.4182, 0.9801]])

8. torch.permute(inputdims) → [Tensor]

方法释义:返回重新排列的张量

参数:

  • input ([Tensor] 要重新排列的张量
  • dims (tuple of python:int) 需要重排的维度索引数组

示例:

a = torch.rand(2, 3, 5)
b = torch.permute(a, (2, 0, 1))   # 按维度索引重新排列
print(b.shape)
c = torch.permute(a, (2, 0, ))   # 给定的 dims 数组长度和 a 维度不一致会报错

# 输出结果
torch.Size([5, 2, 3])
RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 3 is not equal to len(dims) = 2

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值