前言 有时候会面对需要把数据进行维度转换的情况, 比如本来512*N*W*H(BNWH)的维度需要转换为512*(N*W*H)的一个output和(N*W*H)*512的一个output,然后将两者进行矩阵乘法。 即(NHW)*512 X 512*(NWH) = (NHW)*(NHW), 然后再和初始的512*N*W*H进行矩阵乘法,结果仍旧是512*N*W*H,常用在一些non-local conv block中。 代码 import torch import numpy as np input = torch.randn(2,3,4,4