Pytorch学习笔记【3】 --tensor切片
Pytorch笔记目录:点位进入
1. indexing 索引
类似于list的索引操作,tensor也可以使用类似的方法获取tensor中的数值
# create a 4-dim tensor
a = torch.rand(4,3,28,28)
print(a[0].shape)
out:
torch.Size([3, 28, 28])
print(a[0][0].shape)
out:
torch.Size([28, 28])
print(a[0,0,2,4])
out:
tensor(0.2295)
2. 切片
可以通过切片的方式获取tensor中的某一段数据
# select first/last N
a[:2].shape
out:
torch.Size([2, 3, 28, 28])
a[:2,:1,:,:].shape
out:
torch.Size([2, 1, 28, 28])
a.shape
torch.Size([4, 3, 28, 28])
# select by steps
a[:,:,0:28:2,0:28:2].shape
torch.Size([4, 3, 14, 14])
3. …
…可以表示获取当前位置的所有,下面看一个事例
# ... all
a[...].shape
out:
torch.Size([4, 3, 28, 28])
a[0,...].shape
out:
torch.Size([3, 28, 28])
4. 通过掩码来处理
要实现mask操作首先我们需要创建一个规格和你要处理的tensor相同的tensor,然后再进行处理, 程序会输出所有mask值为1的地方
# select by mask
x = torch.randn(3,4)
mask = x.ge(0.5)
print(torch.masked_select(x,mask))
out:
tensor([1.0916, 0.6544, 0.9824, 1.4880, 0.5094])
flatten index
通过散列的index来获取数值,这种方法的规则就是把多维的向量打平进行计算,当然不会经常用到,3维以上就很难计算了
# select by flatten index
src = torch.tensor([[4,3,5],
[6,7,8]])
print(torch.take(src,torch.tensor([0,2,5])))
out:
tensor([4, 5, 8])