PyTorch 合并与分割 (Merging and Splitting Tensors)
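cat: concatenate tensors along an existing dimension (example: statistics about scores, where the two score tables differ only in their first dimension)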
import torch
# torch.cat joins tensors along an EXISTING dimension; every other
# dimension must agree (here both are [*, 32, 8]).
a = torch.rand(4, 32, 8)
print("a.shape:\t", a.shape)
b = torch.rand(5, 32, 8)
print("b.shape:\t", b.shape)

merged = torch.cat((a, b), dim=0)  # 4 + 5 -> 9 along dim 0
c = merged.shape  # only the resulting shape is kept and printed
print("c:\t", c)
a.shape: torch.Size([4, 32, 8])
b.shape: torch.Size([5, 32, 8])
c: torch.Size([9, 32, 8])
import torch
# torch.cat along a non-leading dimension: the inputs may differ only in
# the concatenation dimension; all other dimensions must match.
# (The original also created an unused `a2 = torch.rand(5, 3, 32, 32)`;
# it was never concatenated or printed, so it has been removed.)
a1 = torch.rand(4, 3, 32, 32)
a3 = torch.rand(4, 1, 32, 32)
c = torch.cat([a1, a3], dim=1).shape  # dim 1: 3 + 1 -> 4
print("c:\t", c)

a4 = torch.rand(4, 3, 16, 32)
d = torch.cat([a1, a4], dim=2).shape  # dim 2: 32 + 16 -> 48
print("d:\t", d)
c: torch.Size([4, 4, 32, 32])
d: torch.Size([4, 3, 48, 32])
stack: create a new dimension (all inputs must have exactly the same shape)
import torch
# torch.stack inserts a brand-new dimension at position `dim`; its size
# equals the number of stacked tensors, which must all share one shape.
a1 = torch.rand(4, 3, 32, 32)
a2 = torch.rand(4, 3, 32, 32)
pair = (a1, a2)
b = torch.stack(pair, dim=2).shape  # new size-2 axis inserted at index 2
print("b:\t", b)

c = torch.rand(32, 8)
d = torch.rand(32, 8)
e = torch.stack((c, d), dim=0).shape  # new size-2 leading axis
print("e:\t", e)
b: torch.Size([4, 3, 2, 32, 32])
e: torch.Size([2, 32, 8])
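split: divide a tensor into chunks along a dimension, either by a list of chunk lengths or by a single fixed chunk length: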
import torch
# Build a [2, 32, 8] tensor by stacking two [32, 8] tensors, then split it
# back apart along dim 0 in two equivalent ways.
a = torch.rand(32, 8)
b = torch.rand(32, 8)
print("a.shape:\t", a.shape)
c = torch.stack((a, b), dim=0)
print("c.shape:\t", c.shape)

# Split by an explicit list of chunk lengths (the lengths sum to c.size(0)).
aa, bb = torch.split(c, [1, 1], dim=0)
for label, piece in (("aa:\t", aa), ("bb:\t", bb)):
    print(label, piece.shape)

# Split by one chunk length: every chunk is 1 long along dim 0.
cc, dd = torch.split(c, 1, dim=0)
for label, piece in (("cc:\t", cc), ("dd:\t", dd)):
    print(label, piece.shape)
a.shape: torch.Size([32, 8])
c.shape: torch.Size([2, 32, 8])
aa: torch.Size([1, 32, 8])
bb: torch.Size([1, 32, 8])
cc: torch.Size([1, 32, 8])
dd: torch.Size([1, 32, 8])