TORCH.SEARCHSORTED
日期:2022年8月5日
pytorch版本: 1.11.0
官方文档的链接:https://pytorch.org/docs/stable/generated/torch.searchsorted.html
因为第一次看到这个也不知道什么意思,numpy中也没接触过这个函数,二者的效果应该是一样的。
这里需要注意一下就是sorted_sequence
需要是一个排好序的递增序列,不然可能得到的效果不是你需要的那种。
函数的作用主要就是:
返回一个和values
一样大小的tensor,其中的元素是在sorted_sequence
中满足下列条件的索引i
图中的m
n
l
都是指的维度,x
指当前values
中所看的那个值的下标,i
为最后返回的索引的值。
所以不管是多少维,都是看最里面一层来比较,如果right
为True
,就是左边是可以等于,不然就是右边可以等于,如果为 False
,则返回找到的第一个合适的位置。如果为 True
,则返回最后一个此类索引。
大概就可以理解为寻找一个区间,看能不能满足条件,然后输出右端点的索引i
,然后可能会出现0和超出最大索引的数
官方例子如下,可以再理解一下:
>>> sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]])
>>> sorted_sequence
tensor([[ 1, 3, 5, 7, 9],
[ 2, 4, 6, 8, 10]])
>>> values = torch.tensor([[3, 6, 9], [3, 6, 9]])
>>> values
tensor([[3, 6, 9],
[3, 6, 9]])
>>> torch.searchsorted(sorted_sequence, values)
tensor([[1, 3, 4],
[1, 2, 4]])
>>> torch.searchsorted(sorted_sequence, values, side='right')
tensor([[2, 3, 5],
[1, 3, 4]])
>>> sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9])
>>> sorted_sequence_1d
tensor([1, 3, 5, 7, 9])
>>> torch.searchsorted(sorted_sequence_1d, values)
tensor([[1, 3, 4],
[1, 3, 4]])
如果不是递增的例子:
>>> sorted_sequence_1d = torch.tensor([9, 7, 5, 3, 1])
>>> values = torch.tensor([3, 6, 9])
>>> torch.searchsorted(sorted_sequence_1d, values)
tensor([0, 5, 5])
参考链接:
https://blog.csdn.net/qq_35037684/article/details/125275305