permute():将tensor的维度换位。
permute(多维数组,[维数换位顺序])
例1:
a=rand(2,3,4); %a是一个三维数组,各维的长度分别为:2,3,4
permute(a,[2,1,3]) %实现了交换第一维和第二维,a变成3*2*4的矩阵
例2:
来自行人重识别的一段代码:
def forward(self, x):
b, c, h, w = x.size()
if self.use_spatial:
# spatial attention
#print(x.shape)
theta_xs = self.theta_spatial(x) #[8, 256, 64, 32]->[8, 32, 64, 32]
#print(theta_xs.shape)
phi_xs = self.phi_spatial(x) #[8, 256, 64, 32]->[8, 32, 64, 32]
#print(phi_xs.shape)
theta_xs = theta_xs.view(b, self.inter_channel, -1) # [8, 32, 64, 32]->[8, 32, 64*32]
#print(theta_xs.shape)
theta_xs = theta_xs.permute(0, 2, 1) # [8, 32, 64*32]->[8 64*32 32] 64*32=2048
#print(theta_xs.shape)
phi_xs = phi_xs.view(b, self.inter_channel, -1) # 8 32 64*32
#print(phi_xs.shape)
Gs = torch.matmul(theta_xs, phi_xs)# 8 2048 2048
#print(Gs.shape)