Numpy(高维数组) |Pytorch(张量) 索引读操作笔记

Pytorch 索引张量与 Numpy 索引多维数组类似,但是在 Pytorch 官网文档中并未找到详细的索引读操作说明,而是直接引用了 Numpy 的索引操作说明,因此本文主要以 Numpy 索引高维数组为例。

When accessing the contents of a tensor via indexing, PyTorch follows Numpy behaviors that basic indexing returns views, while advanced indexing returns a copy. Assignment via either basic or advanced indexing is in-place.

索引读高维数组或者张量遵循 Python x[obj] 语法,本文中称 x 为被索引对象, obj 为“索引对象”,根据“索引对象”的不同取值,有“基础索引”, “高级索引”, “字段访问(field access)”三种形式, 本文主要记录“基础索引”与“高级索引”。
 

基础索引

当索引对象是元组,并且元组仅由整数、切片、维度索引工具组成时,遵守基础索引规则。

  • 注意 v1, v2, v3 是元组 (v1, v2, v3) 的等价形式。
  • 维度索引工具(Dimensional Indexing Tools)包括 Ellipsisnewaxis。对象 Ellipsis 在索引对象中表示“所需的若干个切片”,对象 newaxisNone 等价, 表示在指定位置插入一个新的维度。
     

单一元素索引

当需要获取被索引对象的某一个元素值时,索引规则与标准 Python 序列对象索引操作一致,只是将 1 维 Python 序列索引操作拓展到多维,此时被索引对象的维度” 必须与 “索引对象元组的长度相等。

np.arange(10)[0]  # return 0; 被索引数组为1维,索引元组长度为1
np.arange(10).reshape(2, 5)[0, 1]  # return 1, 被索引数组为2维, 索引元组长度需为2

torch.arange(10)[0]  # return tensor(0)
torch.arange(10).reshape(2,5)[0,1]  # return tensor(0)

注意在此情形中,Numpy 的表现与 Pytorch 有差异,Numpy 返回数组数据类型的单一数值,而 Pytorch 返回的是张量。

type(np.arange(5)[0]), type(torch.arange(5)[0])  # <class 'numpy.int64'>, <class 'torch.Tensor'>

索引写操作时 Numpy 返回的应该也是高维数组,而非单一数值,否则赋值声明应该会抛出异常,但以下操作确是正常的:

x = np.arange(5)
x[0] = 99  # x -> array([99,  1,  2,  3,  4])

另外在基础索引情形中,如果“索引对象元组”的长度小于“被索引对象”的维度,在对齐“被索引对象维度” 与 “索引元组” 的条件下,“索引元组”缺失的部分默认赋值:切片值

If the number of objects in the selection tuple is less than N, then : is assumed for any subsequent dimensions.

# 以下两者等价, return array([0, 1, 2, 3, 4])
np.arange(10).reshape(2, 5)[0]
np.arange(10).reshape(2, 5)[0, :]

 

非单一元素索引

基础索引中的非单一元素索引情形,“索引对象元组” 必然包含 “切片对象”, 高维数组或者张量的切片将标准 Python 切片从单维拓展到多维度。

Basic slicing extends Python’s basic concept of slicing to N dimensions. The standard rules of sequence slicing apply to basic slicing on a per-dimension basis (including using a step index).

All arrays generated by basic slicing are always views of the original array.

如果“索引对象元组”中每个元素均为切片,则返回结果的维度数量与“被索引对象”的维度保持一致,但如果“索引对象元组”某个位置的值为一个整数,则该整数值对应的维度在返回结果中将会抛弃掉——标准 Python 序列索引操作中,单一索引值取出单一对象,而切片操作总是返回序列。

An integer, i, returns the same values as i:i+1 except the dimensionality of the returned object is reduced by 1. In particular, a selection tuple with the p-th element an integer (and all other entries is slice) returns the corresponding sub-array with dimension N - 1.

np.arange(10).reshape(2, 5)[:, 0:1]  # return array([[0], [5]]),返回数组维度为2
np.arange(10).reshape(2, 5)[:, 0] # return array([0, 5]), 返回数组维度为1

torch.arange(10).reshape(2, 5)[:, 0:1].shape  # return torch.Size([2, 1])
torch.arange(10).reshape(2, 5)[:, 0]  # return torch.Size([2])

注意 np.arange(10).reshape(2, 5)[:, [0]] 并非“基础索引”的情形,因为“索引对象元组”中的第二个元素是一个列表对象。
 

高级索引

Advanced indexing is triggered when the selection object, obj, is a non-tuple sequence object, an ndarray (of data type integer or bool), or a tuple with at least one sequence object or ndarray (of data type integer or bool). There are two types of advanced indexing: integer and Boolean.

Advanced indexing always returns a copy of the data (contrast with basic slicing that returns a view).

常见的高级索引包括以下形式:

  • 情形一:索引对象不是一个元组序列,而是一个高维数组或者张量,其中布尔型比整数型更常见。
  • 情形二:索引对象是一个元组序列,并且元组序列完全由整数型高维数组或者整数型张量组成。
  • 情形三:索引对象是一个元组序列,并且元组序列完全由列表序列组成。
  • 情形四:索引对象是一个元组序列,元组序列不仅包含高维整数型数组或者高维整数型张量,还包括序列对象。
  • 情形五:索引对象是一个元组序列,元组序列不仅包含高维整数型数组或者高维整数型张量,还包括整数标量。
  • 情形六:索引对象是一个元组序列,元组序列不仅包含高维整数型数组或者高维整数型张量,还包括整数标量、序列对象。
     

情形二:完全由整数型高维数组或者张量组成的元组序列

此种情形下,元组序列中的所有高维数组或者张量首先会进行“广播(broadcast)”,使得所有的高维数组或者张量具有统一的 shape,然后类似 Python zip 函数,将元组序列中的高维数组或者张量中的每一个元素按维度顺序依次组装起来——每一个组装后的元素类似一个基础索引,所有组件最终形成一个迭代器。然后不断迭代从“被索引对象”中取出元素,最终组装成返回结果。如果元组序列的长度小于“被索引对象”维度的长度,对齐元组序列元素与“被索引对象”的维度,缺失的部分相当于赋值:切片。

Advanced indices always are broadcast and iterated as one.

When the index consists of as many integer arrays as dimensions of the array being indexed, the indexing is straightforward, but different from slicing.

In general, the shape of the resultant array will be the concatenation of the shape of the index array (or the shape that all the index arrays were broadcast to) with the shape of any unused dimensions (those not indexed) in the array being indexed.

 

情形三:完全由列表组成元组索引序列

“广播机制” 会扩展所有列表到统一的 shape,其它逻辑与情形二一致。
 

情形四:由高维数组、张量、列表序列组成元组索引序列

“广播机制” 将所有的高维数组、张量、列表序列扩展到统一的 shape, 其它逻辑与情形二一致。
 

情形五:由整数型高维数组或者张量,以及整数标量组成元组序列

当元组序列中某一位置由整数标量占据时,“广播机制”将其它高维数组或者张量拓展到统一的 shape,并且高维数组或者张量内的每一元素均与同一整数标量进行组装。

The broadcasting mechanism permits index arrays to be combined with scalars for other indices. The effect is that the scalar value is used for all the corresponding values of the index arrays.
 

情形六:由整数型高维数组或者张量、以及整数标量、列表序列组成元组索引序列

与情形五类似,只是广播机制考虑的范围从“高维数组、张量” 拓展到 “高维数组、张量、列表序列”。
 

情形一:索引对象为单个布尔型高维数组或者张量

A single boolean index array is practically identical to x[obj.nonzero()]

此种情形下,x[obj] 等价于 x[np.nonzero(obj)] 。首先看 np.nonzero 方法的说明:

Returns a tuple of arrays, one for each dimension of a, containing the indices of the non-zero elements in that dimension. The values in a are always tested and returned in row-major, C-style order.

np.nonzero(obj) 返回一个长度为 obj 维度数量的元组,元组中的每一个元素是一个整数型数组,每一整数型数组中的数值对应 obj 在此维度上元素值为 True 的索引值。

np.nonzero(np.arange(24).reshape(2, 3, 4) > 0)

# 返回结果如下:
(array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 array([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), 
 array([1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))

得到了这些整数型数组,后续按照情形二的逻辑进行索引取值即可。

np.arange(24).reshape(2, 3, 4)[np.arange(24).reshape(2, 3, 4) > 0]

np.arange(24).reshape(2, 3, 4)[(np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
                               np.array([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]),
                               np.array([1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))]

# 返回结果均为array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23])

注意 torch.nonzeronp.nonzero 在返回值上有些差异, torch.nonzero 会返回一个张量,张量的 shape 为 (obj张量中True 值的数量,obj张量的维度),本质就是将 np.nonzero 的返回结果按照维度顺序一一组装起来,如下所示:

torch.nonzero(torch.arange(24).reshape(2, 3, 4) > 0)
# 返回结果如下:
tensor([[0, 0, 1],
        [0, 0, 2],
        [0, 0, 3],
        [0, 1, 0],
        [0, 1, 1],
        [0, 1, 2],
        [0, 1, 3],
        [0, 2, 0],
        [0, 2, 1],
        [0, 2, 2],
        [0, 2, 3],
        [1, 0, 0],
        [1, 0, 1],
        [1, 0, 2],
        [1, 0, 3],
        [1, 1, 0],
        [1, 1, 1],
        [1, 1, 2],
        [1, 1, 3],
        [1, 2, 0],
        [1, 2, 1],
        [1, 2, 2],
        [1, 2, 3]])

由于 torch.nonzeronp.nonzero 在返回值上的差异,因此使用 torch.arange(24).reshape(2, 3, 4) [torch.nonzero(torch.arange(24).reshape(2, 3, 4) > 0)] 会抛出异常。说明在 Pytorch 中, x[obj] 不能等价于 x[torch.nonzero(obj)], 需要增加额外的处理逻辑——将 torch.nonzero 返回的单个张量按照第二个维度分解成多个张量组成的元组序列,如下所示:

y = torch.arange(24).reshape(2, 3, 4) > 0
indices = y.nonzero()
_indices = []
for _ in range(indices.shape[1]):
    _indices.append(torch.LongTensor([item[_] for item in indices]))
print(torch.arange(24).reshape(2, 3, 4)[tuple(_indices)])

当然 torch.arange(24).reshape(2, 3, 4)[torch.arange(24).reshape(2, 3, 4) > 0] 也能得到一样的结果。

也允许将布尔型高维数组或者张量与整数型多维数组或者张量、列表序列、整数标量、切片、维度选择工具混合使用,只是并不常见,以 Numpy 为例,此种情况下的索引规则遵守:

In general if an index includes a Boolean array, the result will be identical to inserting obj.nonzero() into the same position and using the integer array indexing mechanism described above. x[ind_1, boolean_array, ind_2] is equivalent to x[(ind_1,) + boolean_array.nonzero() + (ind_2,)].

 

基础索引 + 高级索引(复杂)

当“索引对象”中同时包含“高级索引”以及“基础索引”时,可以按照 slice, ellipsis , newaxis (不包括整数型标量)将高级索引分割成作为相互独立的部分,然后将所有的高级索引部分的结果组装起来。开发中可以借助打印高维数组或者张量的 shape,查看 shape 是否与预期一致

The easiest way to understand a combination of multiple advanced indices may be to think in terms of the resulting shape. There are two parts to the indexing operation, the subspace defined by the basic indexing (excluding integers) and the subspace from the advanced indexing part. Two cases of index combination need to be distinguished:

  • The advanced indices are separated by a slice, Ellipsis or newaxis. For example x[arr1, :, arr2].
  • The advanced indices are all next to each other. For example x[…, arr1, arr2, :] but not x[arr1, :, 1] since 1 is an advanced index in this regard.

 

实例

在使用 Conditional Bert 做训练数据集增强任务时,没有采用多项式分布随机采用词库,而是想在前 topN 概率值的范围内随机采样,具体实现如下:

import torch

def randomly_sample_from_topN(logits: torch.FloatTensor, topN=50, num_samples=2):
   logits = torch.softmax(logits, dim=2)  # (bsz, seq_len, vocab_size)
   _, index = torch.sort(logits, dim=2, descending=True)
   
   # 第一个维度上的整数索引数组
   int_arr_0 = torch.arange(batch_size)[:, None, None].repeat(1, seq_len, num_samples)
   
   # 第二个维度上的整数索引数组
   int_arr_1 = torch.arange(seq_len)[None, :, None].repeat(batch_size, 1, num_samples)

	# 第三个维度上的整数索引数组
   int_arr_2 = torch.randint(0, topN, size=(batch_size, seq_len, num_samples))
   
   predictions = index[int_arr_0, int_arr_1, int_arr_2]
   assert predictions.shape == (batch_size, seq_len, num_samples)

   # (batch_size, num_samples, seq_len)
   predictions = torch.transpose(predictions, 1, 2)
   return preditions

 

参考

  1. Indexing on ndarrays
  2. tensor view
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值