PyTorch 索引与切片
Indexing
import torch

# A random tensor of shape (4, 3, 28, 28) used throughout this example.
a = torch.rand(4, 3, 28, 28)
# One integer index removes the first dimension: a[i] has shape (3, 28, 28).
print("a[0].shape:\t", a[0].shape)
print("a[1].shape:\t", a[1].shape)
# Two integer indices remove the first two dimensions: shape (28, 28).
print('a[0, 0].shape:\t', a[0, 0].shape)
# The label says a[1, 1], so index a[1, 1] (previously a[0, 0] by mistake).
print("a[1, 1].shape:\t", a[1, 1].shape)
# Indexing every dimension yields a 0-dim (scalar) tensor.
c = a[0, 0, 2, 4]
print("c:\t", c)
a[0].shape: torch.Size([3, 28, 28])
a[1].shape: torch.Size([3, 28, 28])
a[0, 0].shape: torch.Size([28, 28])
a[1, 1].shape: torch.Size([28, 28])
c: tensor(0.6412)
select first / last N
import torch

a = torch.rand(4, 3, 28, 28)
print("a.shape:\t", a.shape)
# A slice on dim 0 alone keeps every remaining dimension intact.
print("a[:2].shape\t", a[:2].shape)
# Trailing full slices are implicit, so a[:2, :1] is the same as a[:2, :1, :, :].
print("a[:2,:1,:,:].shape:\t", a[:2, :1].shape)
# "1:" keeps indices 1 through the end of dim 1.
print("a[:2,1:,:,:].shape:\t", a[:2, 1:].shape)
# A negative start counts from the end: "-1:" keeps only the last entry.
print("a[:2,-1:,:,:].shape:\t", a[:2, -1:].shape)
a.shape: torch.Size([4, 3, 28, 28])
a[:2].shape torch.Size([2, 3, 28, 28])
a[:2,:1,:,:].shape: torch.Size([2, 1, 28, 28])
a[:2,1:,:,:].shape: torch.Size([2, 2, 28, 28])
a[:2,-1:,:,:].shape: torch.Size([2, 1, 28, 28])
select by steps: start:end:step
import torch

a = torch.rand(4, 3, 28, 28)
# start:end:step -- 0:28:2 keeps every second index of the 28, leaving 14.
# The label now matches the expression actually evaluated (both dims use :2).
print("a[:,:,0:28:2,0:28:2].shape:\t", a[:, :, 0:28:2, 0:28:2].shape)
# Omitting start and end ("::2") means the whole dimension with step 2.
print("a[:,:,::2,::2].shape:\t", a[:, :, ::2, ::2].shape)
a[:,:,0:28,0:28:2].shape: torch.Size([4, 3, 14, 14])
a[:,:,::2,::2].shape: torch.Size([4, 3, 14, 14])
select by specific index:
import torch

a = torch.rand(4, 3, 28, 28)
# index_select(dim, index) keeps only the listed positions along `dim`.
b = a.index_select(0, torch.tensor([0, 2])).shape  # positions 0 and 2 of dim 0
print("b:\t", b)
c = a.index_select(1, torch.tensor([1, 2])).shape  # positions 1 and 2 of dim 1
print("c:\t", c)
# torch.arange(28) lists every index of dim 2, so nothing is dropped.
d = a.index_select(2, torch.arange(28)).shape
print("d:\t", d)
# torch.arange(8) is 0..7, so only the first 8 positions of dim 2 remain.
e = a.index_select(2, torch.arange(8)).shape
print("e:\t", e)
b: torch.Size([2, 3, 28, 28])
c: torch.Size([4, 2, 28, 28])
d: torch.Size([4, 3, 28, 28])
e: torch.Size([4, 3, 8, 28])
`...`（Ellipsis）：代表任意多的维度
import torch

a = torch.rand(4, 3, 28, 28)
# `...` must be written as one token; ". . ." (spaced) is a SyntaxError.
# On its own, `...` stands for every dimension, so this is the whole tensor.
b = a[...].shape
print("b:\t", b)
# Index dim 0 and let `...` cover the remaining dimensions.
c = a[0, ...].shape
print("c:\t", c)
# Keep all of dim 0, pick index 1 of dim 1, keep the rest.
d = a[:, 1, ...].shape
print("d:\t", d)
# `...` before the last index: slice only the final dimension.
e = a[..., :2].shape
print("e:\t", e)
b: torch.Size([4, 3, 28, 28])
c: torch.Size([3, 28, 28])
d: torch.Size([4, 28, 28])
e: torch.Size([4, 3, 28, 2])
select by mask:
import torch

x = torch.randn(3, 4)
print("x:\t", x)
# Element-wise x >= 0.5, giving a boolean tensor with the same shape as x.
mask = x.ge(0.5)
print("mask:\t", mask)
# masked_select gathers the entries of x where mask is True into a 1-D tensor.
y = torch.masked_select(x, mask)
print("y:\t", y)
# Reuse y rather than running masked_select a second time just for its shape.
z = y.shape
print("z:\t", z)
x: tensor([[ 0.2149, -1.4181,  0.0112,  2.2036],
        [-0.6523,  0.1513,  0.1381,  0.0905],
        [-0.7174, -1.4634, -0.3409, -1.2119]])
mask: tensor([[False, False, False,  True],
        [False, False, False, False],
        [False, False, False, False]])
y: tensor([2.2036])
z: torch.Size([1])
select by flattened index:
import torch

src = torch.tensor([[4, 3, 5], [6, 7, 8]])
print("src:\t", src)
# take() indexes src as if it were flattened row by row: [4, 3, 5, 6, 7, 8].
a = src.take(torch.tensor([0, 2, 5]))
print("a:\t", a)
src: tensor([[4, 3, 5],
        [6, 7, 8]])
a: tensor([4, 5, 8])