pytorch转onnx时对einsum进行替换
1、einsum简介
2、替换
# Equivalent replacement for: a = torch.einsum('bhnm,bdhm->bdhn', prob, value)
# Insert size-1 axes so broadcasting reproduces the einsum contraction,
# then reduce over the shared last dimension (m).
prob = torch.unsqueeze(prob, 1)    # (b,h,n,m)   -> (b,1,h,n,m)
value = torch.unsqueeze(value, 3)  # (b,d,h,m)   -> (b,d,h,1,m)
a = torch.sum(prob * value, dim=-1)  # (b,d,h,n,m) summed over m -> (b,d,h,n)
- 输出:b d h n **(** 标记被求和消去的维度 m)
- 输入:b # h n m(prob,在 d 的位置补 #)
- 输入:b d h # m(value,在 n 的位置补 #)
如上所示,先把输入输出的维度对齐:在两个输入标记 # 的位置各插入一个大小为 1 的维度(unsqueeze),两者相乘时广播得到 b d h n m;再在标记 ** 对应的维度(即最后一维 m,`dim=-1`)上求和,结果与 einsum 完全等价。转成的 onnx 模型如下所示:
3、测试
>>> a = torch.randn(1,2,3,4)
>>> a
tensor([[[[ 0.3094, 0.3956, -0.2181, -1.7280],
[ 0.2459, -0.3665, 0.5259, -0.4842],
[ 0.3226, 0.1653, -0.0322, -0.2071]],
[[ 0.1408, 1.0664, 0.0070, 2.4841],
[-1.2721, -0.0768, 0.7270, -0.2795],
[-0.3427, 1.2243, -0.8096, -0.2815]]]])
>>> b = torch.randn(1,2,3,5)
>>> b
tensor([[[[ 0.1462, 0.5758, -0.5803, -0.6292, -1.6647],
[ 0.6382, 0.2753, -0.7068, -1.4866, 0.0245],
[-0.8264, 0.3745, 1.7032, -0.8747, -0.6040]],
[[ 0.9619, 0.2243, -0.9576, 1.1650, 0.0468],
[-1.3150, 0.1407, -0.0818, 0.8451, 1.4826],
[-0.3857, -0.4206, 0.3140, -0.6252, 1.3214]]]])
>>> c = torch.einsum('bdhn,bdhm->bhnm',a,b)
>>> c
tensor([[[[ 0.1807, 0.2097, -0.3144, -0.0307, -0.5085],
[ 1.0837, 0.4669, -1.2507, 0.9934, -0.6086],
[-0.0251, -0.1240, 0.1198, 0.1454, 0.3634],
[ 2.1370, -0.4378, -1.3760, 3.9811, 2.9929]],
[[ 1.8297, -0.1113, -0.0698, -1.4406, -1.8800],
[-0.1329, -0.1117, 0.2653, 0.4799, -0.1228],
[-0.6203, 0.2471, -0.4311, -0.1675, 1.0907],
[ 0.0585, -0.1727, 0.3651, 0.4836, -0.4262]],
[[-0.1344, 0.2650, 0.4419, -0.0680, -0.6477],
[-0.6088, -0.4530, 0.6660, -0.9100, 1.5180],
[ 0.3388, 0.3285, -0.3089, 0.5342, -1.0504],
[ 0.2797, 0.0408, -0.4410, 0.3571, -0.2469]]]])
>>> d = a.reshape(1,2,3,4,1)
>>> d
tensor([[[[[ 0.3094],
[ 0.3956],
[-0.2181],
[-1.7280]],
[[ 0.2459],
[-0.3665],
[ 0.5259],
[-0.4842]],
[[ 0.3226],
[ 0.1653],
[-0.0322],
[-0.2071]]],
[[[ 0.1408],
[ 1.0664],
[ 0.0070],
[ 2.4841]],
[[-1.2721],
[-0.0768],
[ 0.7270],
[-0.2795]],
[[-0.3427],
[ 1.2243],
[-0.8096],
[-0.2815]]]]])
>>> e = b.reshape(1,2,3,1,5)
>>> e
tensor([[[[[ 0.1462, 0.5758, -0.5803, -0.6292, -1.6647]],
[[ 0.6382, 0.2753, -0.7068, -1.4866, 0.0245]],
[[-0.8264, 0.3745, 1.7032, -0.8747, -0.6040]]],
[[[ 0.9619, 0.2243, -0.9576, 1.1650, 0.0468]],
[[-1.3150, 0.1407, -0.0818, 0.8451, 1.4826]],
[[-0.3857, -0.4206, 0.3140, -0.6252, 1.3214]]]]])
>>> f = d*e
>>> f
tensor([[[[[ 4.5235e-02, 1.7816e-01, -1.7956e-01, -1.9469e-01, -5.1511e-01],
[ 5.7831e-02, 2.2777e-01, -2.2955e-01, -2.4890e-01, -6.5854e-01],
[-3.1885e-02, -1.2558e-01, 1.2656e-01, 1.3723e-01, 3.6309e-01],
[-2.5261e-01, -9.9492e-01, 1.0027e+00, 1.0872e+00, 2.8766e+00]],
[[ 1.5694e-01, 6.7712e-02, -1.7380e-01, -3.6558e-01, 6.0128e-03],
[-2.3387e-01, -1.0091e-01, 2.5900e-01, 5.4478e-01, -8.9603e-03],
[ 3.3563e-01, 1.4481e-01, -3.7171e-01, -7.8185e-01, 1.2859e-02],
[-3.0900e-01, -1.3332e-01, 3.4221e-01, 7.1980e-01, -1.1839e-02]],
[[-2.6661e-01, 1.2083e-01, 5.4949e-01, -2.8221e-01, -1.9486e-01],
[-1.3662e-01, 6.1917e-02, 2.8157e-01, -1.4461e-01, -9.9848e-02],
[ 2.6572e-02, -1.2043e-02, -5.4765e-02, 2.8127e-02, 1.9421e-02],
[ 1.7111e-01, -7.7549e-02, -3.5265e-01, 1.8112e-01, 1.2506e-01]]],
[[[ 1.3543e-01, 3.1573e-02, -1.3481e-01, 1.6401e-01, 6.5917e-03],
[ 1.0258e+00, 2.3916e-01, -1.0212e+00, 1.2423e+00, 4.9931e-02],
[ 6.7794e-03, 1.5805e-03, -6.7486e-03, 8.2101e-03, 3.2998e-04],
[ 2.3896e+00, 5.5710e-01, -2.3787e+00, 2.8939e+00, 1.1631e-01]],
[[ 1.6727e+00, -1.7903e-01, 1.0400e-01, -1.0750e+00, -1.8860e+00],
[ 1.0096e-01, -1.0805e-02, 6.2766e-03, -6.4881e-02, -1.1383e-01],
[-9.5594e-01, 1.0231e-01, -5.9432e-02, 6.1435e-01, 1.0778e+00],
[ 3.6754e-01, -3.9337e-02, 2.2850e-02, -2.3620e-01, -4.1439e-01]],
[[ 1.3218e-01, 1.4414e-01, -1.0759e-01, 2.1424e-01, -4.5286e-01],
[-4.7222e-01, -5.1497e-01, 3.8439e-01, -7.6540e-01, 1.6179e+00],
[ 3.1224e-01, 3.4051e-01, -2.5417e-01, 5.0610e-01, -1.0698e+00],
[ 1.0856e-01, 1.1839e-01, -8.8369e-02, 1.7596e-01, -3.7194e-01]]]]])
>>> g = torch.sum(f, dim=1)
>>> g
tensor([[[[ 0.1807, 0.2097, -0.3144, -0.0307, -0.5085],
[ 1.0837, 0.4669, -1.2507, 0.9934, -0.6086],
[-0.0251, -0.1240, 0.1198, 0.1454, 0.3634],
[ 2.1370, -0.4378, -1.3760, 3.9811, 2.9929]],
[[ 1.8297, -0.1113, -0.0698, -1.4406, -1.8800],
[-0.1329, -0.1117, 0.2653, 0.4799, -0.1228],
[-0.6203, 0.2471, -0.4311, -0.1675, 1.0907],
[ 0.0585, -0.1727, 0.3651, 0.4836, -0.4262]],
[[-0.1344, 0.2650, 0.4419, -0.0680, -0.6477],
[-0.6088, -0.4530, 0.6660, -0.9100, 1.5180],
[ 0.3388, 0.3285, -0.3089, 0.5342, -1.0504],
[ 0.2797, 0.0408, -0.4410, 0.3571, -0.2469]]]])
>>> c
tensor([[[[ 0.1807, 0.2097, -0.3144, -0.0307, -0.5085],
[ 1.0837, 0.4669, -1.2507, 0.9934, -0.6086],
[-0.0251, -0.1240, 0.1198, 0.1454, 0.3634],
[ 2.1370, -0.4378, -1.3760, 3.9811, 2.9929]],
[[ 1.8297, -0.1113, -0.0698, -1.4406, -1.8800],
[-0.1329, -0.1117, 0.2653, 0.4799, -0.1228],
[-0.6203, 0.2471, -0.4311, -0.1675, 1.0907],
[ 0.0585, -0.1727, 0.3651, 0.4836, -0.4262]],
[[-0.1344, 0.2650, 0.4419, -0.0680, -0.6477],
[-0.6088, -0.4530, 0.6660, -0.9100, 1.5180],
[ 0.3388, 0.3285, -0.3089, 0.5342, -1.0504],
[ 0.2797, 0.0408, -0.4410, 0.3571, -0.2469]]]])
>>> g.shape
torch.Size([1, 3, 4, 5])
>>> c.shape
torch.Size([1, 3, 4, 5])
>>> a = torch.randn(1,2,3,4)
>>> b = torch.randn(1,5,2,4)
>>> c = torch.einsum('bhnm,bdhm->bdhn',a,b)
>>> a
tensor([[[[ 0.2333, -0.4853, -0.6971, -1.2038],
[-0.0799, 1.1245, 0.5262, 0.5892],
[-0.0836, -1.1434, -2.0070, -0.8349]],
[[ 0.1631, 0.8936, 0.9199, 1.1889],
[-0.4036, -0.1402, -1.1928, -1.4278],
[-1.6179, -0.2395, 1.2021, -0.3717]]]])
>>> b
tensor([[[[ 1.9957, -1.4477, -0.5855, 0.5520],
[ 0.2879, 0.5740, 0.4830, -1.0993]],
[[ 0.7745, -0.0874, -1.2396, -0.4058],
[-0.2885, -0.2290, -0.6926, -0.8500]],
[[-1.1198, 0.2451, 0.6506, 1.1265],
[ 0.4085, -1.4830, 0.2969, -0.9224]],
[[-0.0554, 0.8500, 0.1219, 0.9584],
[ 0.2657, 1.1979, -0.2655, -0.0734]],
[[-0.8810, 0.8404, -0.0308, 1.1748],
[-0.6998, 1.8537, -1.4930, -0.8909]]]])
>>> c
tensor([[[[ 0.9118, -1.7703, 2.2027],
[-0.3028, 0.7969, 0.3859]],
[[ 1.5757, -1.0516, 2.8619],
[-1.8994, 2.1884, 0.0050]],
[[-2.1897, 1.3712, -2.4329],
[-2.0823, 1.0060, 0.3942]],
[[-1.6641, 1.5891, -2.0121],
[ 0.7823, 0.1464, -1.0086]],
[[-2.0062, 1.6914, -1.8064],
[-0.8902, 3.0755, -0.7754]]]])
>>> d = a.unsqueeze(1)
>>> e = b.unsqueeze(3)
>>> f = d*e
>>> g = torch.sum(f, dim=-1)
>>> g
tensor([[[[ 0.9118, -1.7703, 2.2027],
[-0.3028, 0.7969, 0.3859]],
[[ 1.5757, -1.0516, 2.8619],
[-1.8994, 2.1884, 0.0050]],
[[-2.1897, 1.3712, -2.4329],
[-2.0823, 1.0060, 0.3942]],
[[-1.6641, 1.5891, -2.0121],
[ 0.7823, 0.1464, -1.0086]],
[[-2.0062, 1.6914, -1.8064],
[-0.8902, 3.0755, -0.7754]]]])
>>>