Tensor:索引操作

索引操作

Tensor支持与numpy.ndarray类似的索引操作,语法上也类似,下面通过一些例子,讲解常用的索引操作。如无特殊说明,索引出来的结果与原tensor共享内存,也即修改一个,另一个会跟着修改。

   In [31]:

a = t.randn(3, 4)
a
Out[31]:
tensor([[ 1.1741,  1.4335, -0.8156,  0.7622],
        [-0.6334, -1.4628, -0.7428,  0.0410],
        [-0.6551,  1.0258,  2.0572,  0.3923]])

In [32]:

a[0] # 第0行(下标从0开始)

Out[32]:

tensor([ 1.1741,  1.4335, -0.8156,  0.7622])

In [33]:

a[:, 0] # 第0列

Out[33]:

tensor([ 1.1741, -0.6334, -0.6551])

In [34]:

a[0][2] # 第0行第2个元素,等价于a[0, 2]

Out[34]:

tensor(-0.8156)

In [35]:

a[0, -1] # 第0行最后一个元素

Out[35]:

tensor(0.7622)

In [36]:

a[:2] # 前两行

Out[36]:

tensor([[ 1.1741,  1.4335, -0.8156,  0.7622],
        [-0.6334, -1.4628, -0.7428,  0.0410]])

In [37]:

a[:2, 0:2] # 前两行,第0,1列

Out[37]:

tensor([[ 1.1741,  1.4335],
        [-0.6334, -1.4628]])

In [38]:

print(a[0:1, :2]) # 第0行,前两列 
print(a[0, :2]) # 注意两者的区别:形状不同
Out[38]:
tensor([[1.1741, 1.4335]])
tensor([1.1741, 1.4335])

In [39]:

# None类似于np.newaxis, 为a新增了一个轴
# 等价于a.view(1, a.shape[0], a.shape[1])
a[None].shape

Out[39]:

torch.Size([1, 3, 4])

In [40]:

a[None].shape # 等价于a[None,:,:]

Out[40]:

torch.Size([1, 3, 4])

In [41]:

a[:,None,:].shape

Out[41]:

torch.Size([3, 1, 4])

In [42]:

a[:,None,:,None,None].shape

Out[42]:

torch.Size([3, 1, 4, 1, 1])

In [43]:

a > 1 # 返回一个ByteTensor

Out[43]:

tensor([[1, 1, 0, 0],
        [0, 0, 0, 0],
        [0, 1, 1, 0]], dtype=torch.uint8)

In [44]:

a[a>1] # 等价于a.masked_select(a>1)
# 选择结果与原tensor不共享内存空间

Out[44]:

tensor([1.1741, 1.4335, 1.0258, 2.0572])

In [45]:

a[t.LongTensor([0,1])] # 第0行和第1行

Out[45]:

tensor([[ 1.1741,  1.4335, -0.8156,  0.7622],
        [-0.6334, -1.4628, -0.7428,  0.0410]])

其它常用的选择函数如表3-2所示。

表3-2常用的选择函数

函数功能
index_select(input, dim, index)在指定维度dim上选取,比如选取某些行、某些列
masked_select(input, mask)例子如上,a[a>0],使用ByteTensor进行选取
non_zero(input)非0元素的下标
gather(input, dim, index)根据index,在dim维度上选取数据,输出的size与index一样

gather是一个比较复杂的操作,对一个2维tensor,输出的每个元素如下:

out[i][j] = input[index[i][j]][j]  # dim=0
out[i][j] = input[i][index[i][j]]  # dim=1

三维tensor的gather操作同理,下面举几个例子。

In [46]:

a = t.arange(0, 16).view(4, 4)
a

Out[46]:

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

In [47]:

# 选取对角线的元素
index = t.LongTensor([[0,1,2,3]])
a.gather(0, index)

Out[47]:

tensor([[ 0,  5, 10, 15]])

In [48]:

# 选取反对角线上的元素
index = t.LongTensor([[3,2,1,0]]).t()
a.gather(1, index)

Out[48]:

tensor([[ 3],
        [ 6],
        [ 9],
        [12]])

In [49]:

# 选取反对角线上的元素,注意与上面的不同
index = t.LongTensor([[3,2,1,0]])
a.gather(0, index)

Out[49]:

tensor([[12,  9,  6,  3]])

In [50]:

# 选取两个对角线上的元素
index = t.LongTensor([[0,1,2,3],[3,2,1,0]]).t()
b = a.gather(1, index)
b

Out[50]:

tensor([[ 0,  3],
        [ 5,  6],
        [10,  9],
        [15, 12]])

gather相对应的逆操作是scatter_gather把数据从input中按index取出,而scatter_是把取出的数据再放回去。注意scatter_函数是inplace操作。

out = input.gather(dim, index)
-->近似逆操作
out = Tensor()
out.scatter_(dim, index)

In [51]:

# 把两个对角线元素放回去到指定位置
c = t.zeros(4,4)
c.scatter_(1, index, b.float())

Out[51]:

tensor([[ 0.,  0.,  0.,  3.],
        [ 0.,  5.,  6.,  0.],
        [ 0.,  9., 10.,  0.],
        [12.,  0.,  0., 15.]])

对tensor的任何索引操作仍是一个tensor,想要获取标准的python对象数值,需要调用tensor.item(), 这个方法只对包含一个元素的tensor适用

In [52]:

a[0,0] #依旧是tensor)

Out[52]:

tensor(0)

In [53]:

a[0,0].item() # python float

Out[53]:

0

In [54]:

d = a[0:1, 0:1, None]
print(d.shape)
d.item() # 只包含一个元素的tensor即可调用tensor.item,与形状无关
torch.Size([1, 1, 1])

Out[54]:

0

In [55]:

# a[0].item()  ->
# raise ValueError: only one element tensors can be converted to Python scalars

高级索引

PyTorch在0.2版本中完善了索引操作,目前已经支持绝大多数numpy的高级索引1。高级索引可以看成是普通索引操作的扩展,但是高级索引操作的结果一般不和原始的Tensor共享内存。


  1. https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing

In [56]:

x = t.arange(0,27).view(3,3,3)
x

Out[56]:

tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[ 9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]],

        [[18, 19, 20],
         [21, 22, 23],
         [24, 25, 26]]])

In [57]:

x[[1, 2], [1, 2], [2, 0]] # x[1,1,2]和x[2,2,0]

Out[57]:

tensor([14, 24])

In [58]:

x[[2, 1, 0], [0], [1]] # x[2,0,1],x[1,0,1],x[0,0,1]

Out[58]:

tensor([19, 10,  1])

In [59]:

x[[0, 2], ...] # x[0] 和 x[2]

Out[59]:

tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[18, 19, 20],
         [21, 22, 23],
         [24, 25, 26]]])
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值