Pytorch 张量维度变化是比较常用和重要的操作,本文主要介绍几种常用的维度变化方法:
1. view()
方法释义:返回当前张量的视图张量;
Pytorch 允许一个 tensor 成为现有 tensor 的一个视图,视图张量与其基础张量共享同样的底层数据。视图张量能够避免明显的数据拷贝,因而能够让我们快速且内存高效地进行张量重塑、切片和逐元素操作。
所以可以通过 t.view() 方法来获取 tensor t 的视图张量。
示例:
a = torch.rand(4, 4)
print(a.shape)
b = a.view(2, 8)
print(b.shape)
print(a.storage().data_ptr()) # a 张量的内存数据指针地址
print(b.storage().data_ptr()) # b 张量的内存数据指针地址,其实是一样的
b[0][0] = 3.14
print(a[0][0]) # 内存共享,所以改变 b,相应 a 数据也会变化
c = a.view(2, 7) # 报错,因为视图张量数据共享,故要 size 必须一样,见输出结果报错信息
# 输出结果
torch.Size([4, 4])
torch.Size([2, 8])
140495468325312
140495468325312
tensor(3.1400)
RuntimeError: shape '[2, 7]' is invalid for input of size 16
2. torch.reshape(input, shape) → [Tensor]
方法释义:返回与 input 张量一样数据和大小的,且与给定 shape 一样的张量。如果可能,返回的是input 张量的视图,否则返回的是其拷贝。
示例:
a = torch.arange(4.)
print(a)
b = torch.reshape(a, (2, 2)) # 返回的是与 shape(2,2)一样的 tensor,数据与 a 一样
print(b)
c = torch.tensor([[0, 1], [2, 3]])
d = torch.reshape(c, (-1,)) # 如果 shape 是 (-1,) ,则返回的与 input 一样数量的一维张量
print(d)
# 输出结果
tensor([0., 1., 2., 3.])
tensor([[0., 1.],
[2., 3.]])
tensor([0, 1, 2, 3])
3. torch.squeeze(input, dim=None) → [Tensor]
方法释义:将 input 张量中所有维度数据为 1 的维度给移除掉。指定了 dim,如果 dim 对应维度的值不为 1 ,则保持不变,为 1 则移除该维度。
示例:
a = torch.ones(2, 1, 2, 1, 3)
print(a.shape)
b = torch.squeeze(a) # 将所有维度中数值为 1 的移除掉
print(b.shape)
c = torch.squeeze(a, 0) # 第一个维度数值不为 1,则不变动
print(c.shape)
d = torch.squeeze(a, 3) # 第四个维度值为 1,则移除第四个维度
print(d.shape)
# 输出结果
torch.Size([2, 1, 2, 1, 3])
torch.Size([2, 2, 3])
torch.Size([2, 1, 2, 1, 3])
torch.Size([2, 1, 2, 3])
4. torch.unsqueeze(input, dim) → [Tensor]
方法释义:在给定的 dim 维度位置插入一个新的维度,维度数值为 1,dim 的范围在 [-dim()-1, dim()+1),包首不包尾
示例:
a = torch.ones(2, 3, 4, 5) # 如 a 的维度为 4,那么指定的 dim 值在 [-5, 4]
b = torch.unsqueeze(a, 0)
print(b.shape)
c = torch.unsqueeze(a, 4)
print(c.shape)
d = torch.unsqueeze(a, -1)
print(d.shape)
e = torch.unsqueeze(a, -5)
print(e.shape)
#输出结果
torch.Size([1, 2, 3, 4, 5])
torch.Size([2, 3, 4, 5, 1])
torch.Size([2, 3, 4, 5, 1])
torch.Size([1, 2, 3, 4, 5])
5. Tensor.expand( *sizes) → [Tensor]
方法释义:返回张量的新视图,其某个维度 size 扩展到更大的 size,如果当前维度 size 为 -1 ,表示当前维度 size 保持不变。
Tensor也可以扩展到更多的维度,新的会追加在最前面。对于新维度,大小不能设置为 -1;
扩展张量不会分配新内存,而只会在现有张量上创建一个新视图。任何大小为1的维度都可以扩展为任意值,而无需分配新内存。
参数:
*sizes (torch.Size or [int] – 需要扩展的 size 值
示例:
a = torch.ones(1, 32, 1, 1)
b = a.expand(4, 32, 5, 3) # 维度为 1 的 size 可以扩展成什么任意的 size
print(b.shape)
c = a.expand(-1, 32, -1, -1) # -1 表示对应的维度 size 不变,但如果 32 改成 33 则会报错
print(c.shape)
d = a.expand(-1, 32, -1, 5) # 最后一个维度 size 扩展成 5
print(d.shape)
e = a.expand(5, 1, 32, 1, 1) # 可以扩展新的维度,但只会放到最前面,不能放到后面(会报错)且不能设置为 -1
print(e.shape)
# 输出结果
torch.Size([4, 32, 5, 3])
torch.Size([1, 32, 1, 1])
torch.Size([1, 32, 1, 5])
torch.Size([5, 1, 32, 1, 1])
6. Tensor.repeat( *sizes) → [Tensor]
方法释义:根据指定维度复制张量,与 expand 不同的是,该方法会拷贝原张量的数据
参数:
sizes (torch.Size or [int] – 指定维度复制的次数
示例:
a = torch.ones(1, 32, 1, 1)
print(a.storage().data_ptr())
b = a.expand(4, 32, 5, 3)
print(b.storage().data_ptr()) # expand 操作后,张量的内存地址没变
c = a.repeat(4, 1, 2, 3) # 重复的次数,返回为 torch.Size([4, 32, 2, 3])
print(c.shape)
print(c.storage().data_ptr()) # repeat 操作后,张量的内存地址会改变
d = a.repeat(4, 1, 2, 3, 2) # 指定的重复次数数组长度超过张量的维度,则会将张量从前面扩展成一样长度的维度再复制,a 会先扩展成 torch.Size([1, 1,32, 2, 3]), 再复制在 torch.Size([4, 1, 64, 3, 2])
print(d.shape)
print(d.storage().data_ptr())
# 输出结果
140220581413824
140220581413824
torch.Size([4, 32, 2, 3])
140220501380608
torch.Size([4, 1, 64, 3, 2])
140220501807104
7. torch.transpose(input, dim0, dim1) → [Tensor]
方法释义:返回 input 张量的转置,dim0 与 dim 1 交换位置
参数:
- input ([Tensor] 输入的张量
- dim0 ([int] 第一个要转置的维度
- dim1 ([int] 第二个要转置的维度
示例:
a = torch.rand(2, 3)
print(a)
b = torch.transpose(a, 0, 1) # 第一维度与第二维度转换位置
print(b)
print(torch.t(a)) # 也可以使用 torch.t(a)
# 输出结果
tensor([[0.6242, 0.2934, 0.4182],
[0.2461, 0.0797, 0.9801]])
tensor([[0.6242, 0.2461],
[0.2934, 0.0797],
[0.4182, 0.9801]])
tensor([[0.6242, 0.2461],
[0.2934, 0.0797],
[0.4182, 0.9801]])
8. torch.permute(input, dims) → [Tensor]
方法释义:返回重新排列的张量
参数:
- input ([Tensor] 要重新排列的张量
- dims (tuple of python:int) 需要重排的维度索引数组
示例:
a = torch.rand(2, 3, 5)
b = torch.permute(a, (2, 0, 1)) # 按维度索引重新排列
print(b.shape)
c = torch.permute(a, (2, 0, )) # 给定的 dims 数组长度和 a 维度不一致会报错
# 输出结果
torch.Size([5, 2, 3])
RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 3 is not equal to len(dims) = 2