Pytorch的数组操作
常用函数
- np.stack
arrays = [np.random.randn(3, 4) for _ in range(10)]
np.stack(arrays, axis=0).shape # (10, 3, 4)
- torch.cat(seq,dim =0)
沿着dim连接seq中的tensor, 所有的tensor必须有相同的size或为empty, 其相反的操作为 torch.split() 和torch.chunk()
import torch
a = torch.tensor([[1,2],
[3,4]])
b = torch.tensor([[2,4],
[6,8]])
e = torch.tensor([[3,9],
[18,27]])
c = torch.cat((a,b,e),0)
d1 = torch.split(c, 3, dim=0)
d2 = torch.chunk(c, 4, dim=0)
print '\nc:\n',c
print '\nd1:\n',d1
print '\nd2:\n',d2
output:
c:
tensor([[ 1, 2],
[ 3, 4],
[ 2, 4],
[ 6, 8],
[ 3, 9],
[18, 27]])
d1:
(tensor([[1, 2],
[3, 4],
[2, 4]]), tensor([[ 6, 8],
[ 3, 9],
[18, 27]]))
d2:
(tensor([[1, 2],
[3, 4]]), tensor([[2, 4],
[6, 8]]), tensor([[ 3, 9],
[18, 27]]))
- torch.stack
#同上, .cat 和 .stack的区别在于 cat会增加现有维度的值,可以理解为续接,stack会新加增加一个维度,可以
理解为叠加 - torch.squeeze
- [np.newaxis,:]
- torch.full
torch.full((2,3),1)
output
tensor([[1., 1., 1.],
[1., 1., 1.]])