torch.argsort()函数组合的效果
前段时间在看何凯明大神MAE的代码的时候发现了下面一段代码:
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
这个其实是在对输入序列进行随机采样的。方式是:
1.假设我们有一个需要采样的序列X[1,2,3,4,5],先创建一个随机值的noise tensor
2.根据noise的值从小到大排序,得到ids_shuffle,这个是noise的值从小到大排序后对应的下标序列。
例如noise:[5,1,6,9,8] ,则 ids_shuffle:[1,0,2,4,3]
3.以ids_shuffle的前n(以3为例)个值作为序列X的下标进行采样得到结果,得到X_ : [2,1,3]。
但是这段代码中有一行让我很好奇,那就是ids_restore = torch.argsort(ids_shuffle, dim=1)这段代码是在做什么呢?
从名字上看,取名为ids_restore ,意思是恢复ids。即恢复下标。但是我很好奇的是,将排序过的序列再排序怎么就能恢复下标了呢?
需要注意的一点是:X_可以认为是从X中随机抽取了N个数,抽取的方式可以是任意的顺序。比如我抽取 X[3] X[0] X[4]
经过思考最终得到的答案如下(还是以上面举的例子来说明):
1.在获得ids_shuffle后,ids_shuffle里面的值是我们要对X进行采样的一个index随机序列。通俗的讲就是,我现在要根据ids_shuffle中的每个index值去获取X。例如ids_shuffle:[1,0,2,4,3],则我们得到的新的序列X_就是 [X[1] , X[0] , X[2] ,…]。
2.这个时候如果对ids_shuffle的值再进行一个排序得到ids_restore,我们得到的ids_restore结果是什么呢? 因为ids_shuffle的值记录的是随机创建的X_采样子序列中每个位置的元素对应原X序列的位置,ids_restore获取的过程可以分为两步理解:(1),对ids_shuffle的值从小到大排序,即原始序列X从0-N的排序,就是X的原始序列位置。(2),获取每个原始位置在ids_shuffle中的index。也就是说如果我们是根据ids_shuffle来获取随机采样的子序列X_,那么ids_restore记录的就是我原始X中按照顺序X[0] X[1] X[2]… 在ids_shuffle中的位置。 例如这里我的X[0]在ids_shuffle中的位置为1, X[1]在ids_shuffle中的位置为0, X[2]的位置为2 ,X[3]的位置为4,X[4]的位置为3.ids_restore:[1,0,2,4,3]
3.那么如果以后的子序列都通过ids_shuffle构建的话,因为它是随机采样,没有位置顺序信息,当我们要将子序列恢复到输入图片patch原来序列的顺序的时候就可以使用ids_restore。按照ids_restore的每个值作为子序列的取值下标得到的序列就是按照原图patch的大小顺序得到的序列了/