onnx不支持torch.einsum算子,很多嵌入式端平台就更不支持了,下面给出用基本的矩阵计算torch.matmul替代orch.einsum算子的代码。
B = 2
D = 3
H = 4
W1 = 5
W2 = 6
fmap1 = torch.randn(B, D, H, W1)
fmap2 = torch.randn(B, D, H, W2)
corr_einsum = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2)
print(corr_einsum.shape) # torch.Size([2, 4, 5, 6])
fmap1 = fmap1.reshape(B*D*H,W1,1)
fmap2 = fmap2.reshape(B*D*H,1,W2)
corr = torch.matmul(fmap1, fmap2)
corr = corr.reshape(B,D,H,W1,W2)
corr = torch.sum(corr, dim=1)
print(corr.shape)
print(corr_einsum.equal(corr))
torch.matmul属于基本矩阵操作,一般嵌入式平台都会支持的,如果连基本的矩阵操作不支持,那就建议跟老板提出换芯片平台吧哈哈~