pytorch中对tensor操作:分片、索引、压缩、扩充、交换维度、拼接、切割、变形
1 根据维度提取子集
1.0 原始数据情况
import torch
#### 先看一下原始数据
a = torch.tensor([[[1,2,3,4],[5,6,7,8],[9,10,11,12]],
[[-1,-2,-3,-4],[-5,-6,-7,-8],[-9,-10,-11,-12]]], dtype=float)
print(a)
# 每个print下面的内容是输出,这里是一个2*3*4的三维矩阵
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[ -1., -2., -3., -4.],
[ -5., -6., -7., -8.],
[ -9., -10., -11., -12.]]], dtype=torch.float64)
1.1 根据第一个维度提取一个子集
#### 根据第一个维度提取第一个元素,结果是一个3*4的矩阵
print(a[0])
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]], dtype=torch.float64)
#### 根据第一个维度提取前两个元素,结果是一个2*3*4的矩阵,其实等于a,因为a在第一个维度也就两个元素
print(a[0:2])
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[ -1., -2., -3., -4.],
[ -5., -6., -7., -8.],
[ -9., -10., -11., -12.]]], dtype=torch.float64)
1.2 根据前两个维度提取一个子集
#### 方法1.2.1
#### 提取第一个维度的第2个元素,再从中提取第二个维度的第3个元素,
#### 结果是一个向量
print(a[1,2])
tensor([ -9., -10., -11., -12.], dtype=torch.float64)
#### 方法1.2.2
#### 提取第一个维度的前两个元素,再从中提取第二个维度的1:2维度(也就是第1个元素)
#### 结果是一个2*1*4的矩阵
print(a[0:2,1:2])
tensor([[[ 5., 6., 7., 8.]],
[[-5., -6., -7., -8.]]], dtype=torch.float64)
#### 方法1.2.3
#### 提取第一个维度的前两个元素,再从中提取第二个维度的第1个元素
#### 注意结果是一个2*4的矩阵
print(a[0:2,1])
tensor([[ 5., 6., 7., 8.],
[-5., -6., -7., -8.]], dtype=torch.float64)
#### 方法1.2.4
#### 提取第二个维度的第3个元素,其他维度的元素全部提取
print(a[:,2]) # 同 print(a[:,2,:]),也就是维度比第二个维度大的下标可以忽略,默认全部提取
tensor([[ 9., 10., 11., 12.],
[ -9., -10., -11., -12.]], dtype=torch.float64)
注意:上面的方法1.2.1最后的维数是1维,和原始数据比下降了两个维度。方法1.2.2和1.2.3想要获得的数据是一致的,但是维度不同。方法1.2.3下降了一个维度。从上面我们可以发现,用n个固定的标量来作为下标,会使得结果比原始数据降低n个维数。比如方法1.2.1种有两个固定标量(1和2),所以维数从3维下降成1维。方法1.2.3种有一个固定标量(1),所以下降了两维。方法1.2.4种有一个固定标量(2),所以下降了一维。
1.3 提取某个特定的元素的值
#### 提取第一个维度的第1个元素,再从中提取第二个维度的第3个元素,再从中提取第三个维度的第2个元素
print(a[0,2,1])
tensor(10., dtype=torch.float64)
#### 将这个值从tensor变量转成python中的数值变量
print(a[0,2,1].item())
10.0
2 对数据进行压缩和扩充:torch.squeeze() 和torch.unsqueeze()
2.1 squeeze()将元素个数只有1的维度压缩掉
#### 先看一下b长什么样,是一个2*1*4的3维矩阵
b = a[:,1:2]
print(b)
tensor([[[ 5., 6., 7., 8.]],
[[-5., -6., -7., -8.]]], dtype=torch.float64)
#### 将第二个维度压缩掉,因为第二个维度的元素个数只有1,所以可以压缩
c = b.squeeze(1) # 等价于 c = torch.squeeze(b,1)
print(c) # 看一下压缩后的结果,是一个2*4的矩阵
tensor([[ 5., 6., 7., 8.],
[-5., -6., -7., -8.]], dtype=torch.float64)
print(b) # 发现b没有变化,也就是torch.squeeze()会返回一个tensor,而不是inplace的操作
tensor([[[ 5., 6., 7., 8.]],
[[-5., -6., -7., -8.]]], dtype=torch.float64)
# tensor.squeeze_()是inplace操作
b.squeeze_(1)
print(b)
tensor([[ 5., 6., 7., 8.],
[-5., -6., -7., -8.]], dtype=torch.float64)
#### 将b中所有只有一个元素的维度都压缩
b = a[:,1:2,2:3]
print(b)
tensor([[[ 7.]],
[[-7.]]], dtype=torch.float64)
print(b.squeeze())
tensor([ 7., -7.], dtype=torch.float64)
2.2 unsqueeze()对数据进行扩充维度
#### 先看一下数据的情况
b = a[0:2,1]
print(b)
tensor([[ 5., 6., 7., 8.],
[-5., -6., -7., -8.]], dtype=torch.float64)
print(b.size())
torch.Size([2, 4])
#### 在第一个维度进行扩充
print(b.unsqueeze(0)) # 等价于print(torch.unsqueeze(b,0))
tensor([[[ 5., 6., 7., 8.],
[-5., -6., -7., -8.]]], dtype=torch.float64)
print(b.unsqueeze(0).size())
torch.Size([1, 2, 4])
#### 在第二个维度进行扩充
# 等于将方法1.2.3的结果在第2个维度上进行扩充,变成和方法1.2.2是一样的结果
print(b.unsqueeze(1))
tensor([[[ 5., 6., 7., 8.]],
[[-5., -6., -7., -8.]]], dtype=torch.float64)
print(b.unsqueeze(1).size())
torch.Size([2, 1, 4])
#### 在第三个维度进行扩充
print(b.unsqueeze(2))
tensor([[[ 5.],
[ 6.],
[ 7.],
[ 8.]],
[[-5.],
[-6.],
[-7.],
[-8.]]], dtype=torch.float64)
print(torch.unsqueeze(b,2))
torch.Size([2, 4, 1])
3 对数据维度进行交换:tensor.permute()
permute可以对数据维度进行交换,数据本身不变
#### 先看一下原始数据
print(a)
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[ -1., -2., -3., -4.],
[ -5., -6., -7., -8.],
[ -9., -10., -11., -12.]]], dtype=torch.float64)
print(a.size())
torch.Size([2, 3, 4])
#### 将原始数据a的第2维给新数据b的第1维;第1维给第二维;第3维给第3维
b= a.permute(1,0,2)
print(b)
tensor([[[ 1., 2., 3., 4.],
[ -1., -2., -3., -4.]],
[[ 5., 6., 7., 8.],
[ -5., -6., -7., -8.]],
[[ 9., 10., 11., 12.],
[ -9., -10., -11., -12.]]], dtype=torch.float64)
print(b.size())
torch.Size([3, 2, 4])
#### a本身不变
print(a)
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[ -1., -2., -3., -4.],
[ -5., -6., -7., -8.],
[ -9., -10., -11., -12.]]], dtype=torch.float64)
4 对数据进行拼接:torch.cat(), torch.stack()
4.1 cat
指定维度,利用cat对多个数据进行拼接,拼接前后的总维数不变
#### 先看一下数据
a = torch.tensor([[1,2,3,4],[5,6,7,8]])
b = torch.tensor([[9,10,11,12],[13,14,15,16]])
print(a)
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
print(b)
tensor([[ 9, 10, 11, 12],
[13, 14, 15, 16]])
#### 根据第一维度拼接
print(torch.cat((a,b),0)) # 等价于print(torch.cat((a,b))
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16]])
#### 根据第二维度拼接
print(torch.cat((a,b),1))
tensor([[ 1, 2, 3, 4, 9, 10, 11, 12],
[ 5, 6, 7, 8, 13, 14, 15, 16]])
#### 还可以拼接多个
print(torch.cat((a,b,a))
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16],
[ 1, 2, 3, 4],
[ 5, 6, 7, 8]])
4.2 stack
指定维度,对多个数据进行拼接,拼接后总维数增加1
#### 按照第一个维度堆叠
print(torch.stack((a,b),0)) # 等价于torch.stack((a,b))
tensor([[[ 1, 2, 3, 4],
[ 5, 6, 7, 8]],
[[ 9, 10, 11, 12],
[13, 14, 15, 16]]])
print(torch.stack((a,b)).size()) # 2维变成3维
torch.Size([2, 2, 4])
#### 按照第二个维度堆叠
print(torch.stack((a,b),1))
tensor([[[ 1, 2, 3, 4],
[ 9, 10, 11, 12]],
[[ 5, 6, 7, 8],
[13, 14, 15, 16]]])
#### 按照第三个维度堆叠
print(torch.stack((a,b),2))
tensor([[[ 1, 9],
[ 2, 10],
[ 3, 11],
[ 4, 12]],
[[ 5, 13],
[ 6, 14],
[ 7, 15],
[ 8, 16]]])
#### 和cat一样,可以对多个数据进行stack
print(torch.stack((a,b,a),2))
tensor([[[ 1, 9, 1],
[ 2, 10, 2],
[ 3, 11, 3],
[ 4, 12, 4]],
[[ 5, 13, 5],
[ 6, 14, 6],
[ 7, 15, 7],
[ 8, 16, 8]]])
5 对数据进行切割:torch.split()
利用split对数据进行切割,split的第二个参数可以是一个数字也可以是一个list,第三个参数是维度。切割后的数据维度和原始数据一致
#### 先看一下数据
a = torch.arange(1,16).reshape(5,3)
print(a)
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12],
[13, 14, 15]])
#### 均匀切割。根据第一个维度将a切割,每块包含2个元素,最后不足的就有多少输出多少
x = torch.split(a,2,0) # 获得3块结果,每块结果的维度和原始数据一致
print(x[0])
tensor([[1, 2, 3],
[4, 5, 6]])
print(x[1])
tensor([[ 7, 8, 9],
[10, 11, 12]])
print(x[2]) # 因为最后一块数据不足,所以只有一行,而不是两行
tensor([[13, 14, 15]])
#### 均匀切割。根据第二个维度进行切割,每块包含2个元素
x = torch.split(a,2,1)
print(x[0])
tensor([[ 1, 2],
[ 4, 5],
[ 7, 8],
[10, 11],
[13, 14]])
print(x[1])
tensor([[ 3],
[ 6],
[ 9],
[12],
[15]])
#### 自定义切割。根据第二个维度切割,一共切割成两块,第一个块包含1个元素(也就是1列),第二块包含2个元素(也就是2列)
x = torch.split(a,[1,2],1)
print(x[0])
tensor([[ 1],
[ 4],
[ 7],
[10],
[13]])
print(x[1])
tensor([[ 2, 3],
[ 5, 6],
[ 8, 9],
[11, 12],
[14, 15]])
6 对数据进行变形:tensor.reshape()
利用reshape来改变数据的形状和维数,类似于view,但是比view更加强大,特别是在数据不是连续(比如转置过后)的时候也适用。
#### 先看一下数据
a = torch.arange(1,16).reshape(5,3)
print(a)
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12],
[13, 14, 15]])
#### 将a变成3*5的矩阵
b = a.reshape(3,5)
print(b)
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10],
[11, 12, 13, 14, 15]])
#### 将a变成1*15的矩阵
b = a.reshape(1,15)
print(b)
tensor([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])
#### 将a变成元素个数为15的向量
b = a.reshape(15)
print(b)
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
# 注意这里是一维数组,上面的1*15矩阵是二维数组