pytorch常用的一些函数
- unsqueeze()函数介绍:在指定维度上增加维度。
例如:a = t.arange(0,6) a.view(2,3) #a的维度为(2,3) a.unsqueeze(1) #增加第二个维度 ,a的维度为(2,1,3)
- tensor.expand_as():把一个tensor变成和函数括号内一样形状的tensor,用法与expand()类似
例如:x = torch.tensor([[1],[2],[3]]) #x.size为(3,1) x.expand(3,4) #x.size为(3,4)
- gather()函数:沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合.
torch.gather(input, dim, index, out=None) → Tensor
例如:
观察它的输出结果:b = torch.Tensor([[1,2,3],[4,5,6]]) print b index_1 = torch.LongTensor([[0,1],[2,0]]) index_2 = torch.LongTensor([[0,1,1],[0,0,0]]) print torch.gather(b, dim=1, index=index_1) print torch.gather(b, dim=0, index=index_2)
1 2 3 4 5 6 [torch.FloatTensor of size 2x3] 1 2 6 4 [torch.FloatTensor of size 2x2] 1 5 6 1 2 3 [torch.FloatTensor of size 2x3]
- permute(dims):将tensor的维度换位。
例如:import torch import numpy as np a=np.array([[[1,2,3],[4,5,6]]]) unpermuted=torch.tensor(a) print(unpermuted.size()) # ——> torch.Size([1, 2, 3]) permuted=unpermuted.permute(2,0,1) print(permuted.size()) # ——> torch.Size([3, 1, 2])
- cat()函数:对数据沿着某一维度进行拼接。cat后数据的总维数不变
例如:下面代码对两个2维tensor(分别为2*3,1*3)进行拼接,拼接完后变为3*3还是2维的tensor。
结果:import torch torch.manual_seed(1) x = torch.randn(2,3) y = torch.randn(1,3) print(x,y)
0.6614 0.2669 0.0617 0.6213 -0.4519 -0.1661 [torch.FloatTensor of size 2x3] -1.5228 0.3817 -1.0276 [torch.FloatTensor of size 1x3]