def feature_fusion(self, x_spatial, x_spectral, pool='mean', scale=None):
    """Fuse spatial and spectral token sequences via bidirectional cross-attention.

    Each stream attends over the other using dot-product affinities. The
    attended features are added back to that stream as a residual. Each
    stream is then pooled to a single vector, and the two vectors are summed.

    Args:
        x_spatial: Tensor of shape (..., N_spa, D); presumably (B, 26, 1024)
            in the intended setup -- shapes are not validated here.
        x_spectral: Tensor of shape (..., N_spe, D); presumably (B, 104, 1024).
            Must share the leading dims and the feature dim D with
            ``x_spatial``.
        pool: ``'mean'`` averages over the token axis. Any other value selects
            the first token (index 0) of each stream, e.g. a class token.
        scale: Multiplier applied to the affinity logits before softmax.
            ``None`` (default) uses ``1 / sqrt(D)`` as in scaled dot-product
            attention. Pass ``1.0`` to reproduce the earlier unscaled behaviour.

    Returns:
        Tensor of shape (..., D): pooled spatial features plus pooled spectral
        features.
    """
    if scale is None:
        # Unscaled dot products grow with D, pushing softmax into a saturated,
        # near one-hot regime with vanishing gradients.
        scale = x_spatial.size(-1) ** -0.5

    # One affinity matrix serves both directions: the spectral->spatial scores
    # are exactly its transpose, so a second matmul is unnecessary.
    affinity = torch.matmul(x_spatial, x_spectral.transpose(-2, -1)) * scale  # (..., N_spa, N_spe)

    # Spatial tokens attend over spectral tokens (weights normalised over N_spe).
    spa_attn = torch.softmax(affinity, dim=-1)
    # Spectral tokens attend over spatial tokens (weights normalised over N_spa).
    spe_attn = torch.softmax(affinity.transpose(-2, -1), dim=-1)

    # Residual connections keep each stream's own features alongside the
    # cross-attended ones.
    fused_spatial = torch.matmul(spa_attn, x_spectral) + x_spatial    # (..., N_spa, D)
    fused_spectral = torch.matmul(spe_attn, x_spatial) + x_spectral   # (..., N_spe, D)

    if pool == 'mean':
        pooled_spatial = fused_spatial.mean(dim=-2)
        pooled_spectral = fused_spectral.mean(dim=-2)
    else:
        pooled_spatial = fused_spatial[..., 0, :]
        pooled_spectral = fused_spectral[..., 0, :]
    return pooled_spatial + pooled_spectral
论文链接在此（原链接缺失，待补充）。目前尚未达到预期效果，问题可能出在 softmax 那一步。先把代码保存在这里。