pytorch基本数据类型
- 32位浮点型:torch.FloatTensor。pyorch.Tensor()默认的就是这种类型。
- 64位整型:torch.LongTensor。
- 32位整型:torch.IntTensor。
- 16位整型:torch.ShortTensor。
- 64位浮点型:torch.DoubleTensor。
view用法
相当于numpy中resize()的功能,但是用法可能不太一样
把原先tensor中的数据按照行优先的顺序排成一个一维的数据
(这里应该是因为要求地址是连续存储的),然后按照参数组合成其他维度的tensor。
例1:
a=torch.Tensor([[[1,2,3],[4,5,6]]])
b=torch.Tensor([1,2,3,4,5,6])
print(a.view(1,6))
print(b.view(1,6))
输出:
tensor([[1., 2., 3., 4., 5., 6.]])
例2:
a=torch.Tensor([[[1,2,3],[4,5,6]]])
print(a.view(3,2)) # a.view(-1,2)
相当于就是从1,2,3,4,5,6顺序的拿数组来填充需要的形状,输出:
tensor([[1., 2.],
[3., 4.],
[5., 6.]])
gather用法
沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合
例:
a1 = torch.randint(0, 30, (2, 3, 5))
print(a1)
print(a1.shape)
index = torch.LongTensor([[[0,1,2,0,2],
[0,0,0,0,0],
[1,1,1,1,1]],
[[1,2,2,2,2],
[0,0,0,0,0],
[2,2,2,2,2]]])
print(a1.size()==index.size())
b1 = a1.gather(1,index) # dim=1时,沿着列的方向选择元素
print(b1)
输出:
tensor([[[20, 19, 9, 4, 3],
[ 6, 29, 24, 3, 9],
[14, 10, 11, 22, 1]],
[[ 7, 10, 3, 1, 2],
[ 2, 2, 29, 11, 16],
[19, 16, 12, 29, 22]]])
torch.Size([2, 3, 5])
True
tensor([[[20, 29, 11, 4, 1],
[20, 19, 9, 4, 3],
[ 6, 29, 24, 3, 9]],
[[ 2, 16, 12, 29, 22],
[ 7, 10, 3, 1, 2],
[19, 16, 12, 29, 22]]])