今天看网上DCN的解读,顺带看了代码,决定这次把unsqueeze和gather搞明白。
unsqueeze的作用跟字面意思一样,与squeeze作用相反,就是扩维,请看代码。
>> > import torch
>> > a = torch. tensor( [ 1 , 2 , 3 , 4 , 5 ] )
>> > a. shape
torch. Size( [ 5 ] )
可以看到a的维度为1.
>> > b= a. unsqueeze( dim= 0 )
>> > b
tensor( [ [ 1 , 2 , 3 , 4 , 5 ] ] )
在第0维对a扩维,可以看到a的维度变为了2,再看一下shape.
>> > b. shape
torch. Size( [ 1 , 5 ] )
可以看到在原先的第0维变为了第1维。
>> > c= b. unsqueeze( dim= - 1 )
>> > c
tensor( [ [ [ 1 ] ,
[ 2 ] ,
[ 3 ] ,
[ 4 ] ,
[ 5 ] ] ] )
在最后一维对b进行扩维,再看一下shape。
>> > c. shape
torch. Size( [ 1 , 5 , 1 ] )
可以看到c的维度增加了,在最后一维,再向c加一个维度.
>> > d= c. unsqueeze( dim= - 1 )
>> > d
tensor( [ [ [ [ 1 ] ] ,
[ [ 2 ] ] ,
[ [ 3 ] ] ,
[ [ 4 ] ] ,
[ [ 5 ] ] ] ] )
>> > d. shape
torch. Size( [ 1 , 5 , 1 , 1 ] )
再验证下squeeze。
>> > e = d. squeeze( 0 )
>> > e. shape
torch. Size( [ 5 , 1 , 1 ] )
>> > e = d. squeeze( 1 )
>> > e. shape
torch. Size( [ 1 , 5 , 1 , 1 ] )
>> > e = d. squeeze( - 1 )
>> > e. shape
torch. Size( [ 1 , 5 , 1 ] )
>> > e = d. squeeze( - 2 )
>> > e. shape
torch. Size( [ 1 , 5 , 1 ] )
可以看到squeeze只能将多余的维度降维,包含数据的维度是没有用的。再看一下gather。
gather的作用是按index读取input中的数据。index要求与input维度相同。
>> > a
tensor( [ [ 1 , 2 , 3 , 4 ] ,
[ 2 , 3 , 4 , 5 ] ,
[ 3 , 4 , 5 , 6 ] ,
[ 4 , 5 , 6 , 7 ] ] )
>> > index
tensor( [ [ 0 ] ,
[ 1 ] ,
[ 2 ] ,
[ 3 ] ] )
>> > a. shape
torch. Size( [ 4 , 4 ] )
>> > index. shape
torch. Size( [ 4 , 1 ] )
首先给定一个输入a,维度为2.index的维度也为2,两者维度要相同,而且若要gather对input第i维进行操作,那么index在第i维的维度应大于等于1.(翻译文档)因为若对input第i维进行操作,input的第i维定不为空,这样index在该维度就不能没有数值,也就是大于等于1.
>> > out = torch. gather( a, 0 , index)
>> > out
tensor( [ [ 1 ] ,
[ 2 ] ,
[ 3 ] ,
[ 4 ] ] )
>> > out = torch. gather( a, 1 , index)
>> > out
tensor( [ [ 1 ] ,
[ 3 ] ,
[ 5 ] ,
[ 7 ] ] )