【模型转换tips】爱因斯坦求和约定(einsum)的替换

pytorch转onnx时对einsum进行替换

 

1、einsum简介

一文学会 Pytorch 中的 einsum

2、替换

	# a = torch.einsum('bhnm,bdhm->bdhn', prob, value)
    prob = prob.unsqueeze(1)
    value = value.unsqueeze(3)
    a = prob*value
    a = torch.sum(a, dim=-1)
  • 输出:        b d  h n **   (** 处是被求和消去的 m 维)
  • 输入 prob:   b #  h n m    (# 处用 unsqueeze(1) 插入大小为 1 的 d 维)
  • 输入 value:  b d  h # m    (# 处用 unsqueeze(3) 插入大小为 1 的 n 维)

如上所示,先把输入、输出的维度按 b d h n m 对齐:在两个输入标记 # 的位置各插入一个大小为 1 的维度,然后利用广播机制逐元素相乘,得到形状为 b d h n m 的结果,最后在输出标记 ** 的位置(即最后的 m 维)求和,即与 einsum 的结果完全一致。转成的 onnx 模型如下所示:
(图:转换后的 ONNX 模型结构,einsum 被替换为 Unsqueeze、Mul、ReduceSum 等节点;原文此处为图片)

3、测试

>>> a = torch.randn(1,2,3,4)
>>> a
tensor([[[[ 0.3094,  0.3956, -0.2181, -1.7280],
          [ 0.2459, -0.3665,  0.5259, -0.4842],
          [ 0.3226,  0.1653, -0.0322, -0.2071]],

         [[ 0.1408,  1.0664,  0.0070,  2.4841],
          [-1.2721, -0.0768,  0.7270, -0.2795],
          [-0.3427,  1.2243, -0.8096, -0.2815]]]])
>>> b = torch.randn(1,2,3,5)
>>> b
tensor([[[[ 0.1462,  0.5758, -0.5803, -0.6292, -1.6647],
          [ 0.6382,  0.2753, -0.7068, -1.4866,  0.0245],
          [-0.8264,  0.3745,  1.7032, -0.8747, -0.6040]],

         [[ 0.9619,  0.2243, -0.9576,  1.1650,  0.0468],
          [-1.3150,  0.1407, -0.0818,  0.8451,  1.4826],
          [-0.3857, -0.4206,  0.3140, -0.6252,  1.3214]]]])
>>> c = torch.einsum('bdhn,bdhm->bhnm',a,b)
>>> c
tensor([[[[ 0.1807,  0.2097, -0.3144, -0.0307, -0.5085],
          [ 1.0837,  0.4669, -1.2507,  0.9934, -0.6086],
          [-0.0251, -0.1240,  0.1198,  0.1454,  0.3634],
          [ 2.1370, -0.4378, -1.3760,  3.9811,  2.9929]],

         [[ 1.8297, -0.1113, -0.0698, -1.4406, -1.8800],
          [-0.1329, -0.1117,  0.2653,  0.4799, -0.1228],
          [-0.6203,  0.2471, -0.4311, -0.1675,  1.0907],
          [ 0.0585, -0.1727,  0.3651,  0.4836, -0.4262]],

         [[-0.1344,  0.2650,  0.4419, -0.0680, -0.6477],
          [-0.6088, -0.4530,  0.6660, -0.9100,  1.5180],
          [ 0.3388,  0.3285, -0.3089,  0.5342, -1.0504],
          [ 0.2797,  0.0408, -0.4410,  0.3571, -0.2469]]]])
>>> d = a.reshape(1,2,3,4,1)
>>> d
tensor([[[[[ 0.3094],
           [ 0.3956],
           [-0.2181],
           [-1.7280]],

          [[ 0.2459],
           [-0.3665],
           [ 0.5259],
           [-0.4842]],

          [[ 0.3226],
           [ 0.1653],
           [-0.0322],
           [-0.2071]]],


         [[[ 0.1408],
           [ 1.0664],
           [ 0.0070],
           [ 2.4841]],

          [[-1.2721],
           [-0.0768],
           [ 0.7270],
           [-0.2795]],

          [[-0.3427],
           [ 1.2243],
           [-0.8096],
           [-0.2815]]]]])
>>> e = b.reshape(1,2,3,1,5)
>>> e
tensor([[[[[ 0.1462,  0.5758, -0.5803, -0.6292, -1.6647]],

          [[ 0.6382,  0.2753, -0.7068, -1.4866,  0.0245]],

          [[-0.8264,  0.3745,  1.7032, -0.8747, -0.6040]]],


         [[[ 0.9619,  0.2243, -0.9576,  1.1650,  0.0468]],

          [[-1.3150,  0.1407, -0.0818,  0.8451,  1.4826]],

          [[-0.3857, -0.4206,  0.3140, -0.6252,  1.3214]]]]])
>>> f = d*e
>>> f
tensor([[[[[ 4.5235e-02,  1.7816e-01, -1.7956e-01, -1.9469e-01, -5.1511e-01],
           [ 5.7831e-02,  2.2777e-01, -2.2955e-01, -2.4890e-01, -6.5854e-01],
           [-3.1885e-02, -1.2558e-01,  1.2656e-01,  1.3723e-01,  3.6309e-01],
           [-2.5261e-01, -9.9492e-01,  1.0027e+00,  1.0872e+00,  2.8766e+00]],

          [[ 1.5694e-01,  6.7712e-02, -1.7380e-01, -3.6558e-01,  6.0128e-03],
           [-2.3387e-01, -1.0091e-01,  2.5900e-01,  5.4478e-01, -8.9603e-03],
           [ 3.3563e-01,  1.4481e-01, -3.7171e-01, -7.8185e-01,  1.2859e-02],
           [-3.0900e-01, -1.3332e-01,  3.4221e-01,  7.1980e-01, -1.1839e-02]],

          [[-2.6661e-01,  1.2083e-01,  5.4949e-01, -2.8221e-01, -1.9486e-01],
           [-1.3662e-01,  6.1917e-02,  2.8157e-01, -1.4461e-01, -9.9848e-02],
           [ 2.6572e-02, -1.2043e-02, -5.4765e-02,  2.8127e-02,  1.9421e-02],
           [ 1.7111e-01, -7.7549e-02, -3.5265e-01,  1.8112e-01,  1.2506e-01]]],


         [[[ 1.3543e-01,  3.1573e-02, -1.3481e-01,  1.6401e-01,  6.5917e-03],
           [ 1.0258e+00,  2.3916e-01, -1.0212e+00,  1.2423e+00,  4.9931e-02],
           [ 6.7794e-03,  1.5805e-03, -6.7486e-03,  8.2101e-03,  3.2998e-04],
           [ 2.3896e+00,  5.5710e-01, -2.3787e+00,  2.8939e+00,  1.1631e-01]],

          [[ 1.6727e+00, -1.7903e-01,  1.0400e-01, -1.0750e+00, -1.8860e+00],
           [ 1.0096e-01, -1.0805e-02,  6.2766e-03, -6.4881e-02, -1.1383e-01],
           [-9.5594e-01,  1.0231e-01, -5.9432e-02,  6.1435e-01,  1.0778e+00],
           [ 3.6754e-01, -3.9337e-02,  2.2850e-02, -2.3620e-01, -4.1439e-01]],

          [[ 1.3218e-01,  1.4414e-01, -1.0759e-01,  2.1424e-01, -4.5286e-01],
           [-4.7222e-01, -5.1497e-01,  3.8439e-01, -7.6540e-01,  1.6179e+00],
           [ 3.1224e-01,  3.4051e-01, -2.5417e-01,  5.0610e-01, -1.0698e+00],
           [ 1.0856e-01,  1.1839e-01, -8.8369e-02,  1.7596e-01, -3.7194e-01]]]]])
>>> g = torch.sum(f, dim=1)
>>> g
tensor([[[[ 0.1807,  0.2097, -0.3144, -0.0307, -0.5085],
          [ 1.0837,  0.4669, -1.2507,  0.9934, -0.6086],
          [-0.0251, -0.1240,  0.1198,  0.1454,  0.3634],
          [ 2.1370, -0.4378, -1.3760,  3.9811,  2.9929]],

         [[ 1.8297, -0.1113, -0.0698, -1.4406, -1.8800],
          [-0.1329, -0.1117,  0.2653,  0.4799, -0.1228],
          [-0.6203,  0.2471, -0.4311, -0.1675,  1.0907],
          [ 0.0585, -0.1727,  0.3651,  0.4836, -0.4262]],

         [[-0.1344,  0.2650,  0.4419, -0.0680, -0.6477],
          [-0.6088, -0.4530,  0.6660, -0.9100,  1.5180],
          [ 0.3388,  0.3285, -0.3089,  0.5342, -1.0504],
          [ 0.2797,  0.0408, -0.4410,  0.3571, -0.2469]]]])
>>> c
tensor([[[[ 0.1807,  0.2097, -0.3144, -0.0307, -0.5085],
          [ 1.0837,  0.4669, -1.2507,  0.9934, -0.6086],
          [-0.0251, -0.1240,  0.1198,  0.1454,  0.3634],
          [ 2.1370, -0.4378, -1.3760,  3.9811,  2.9929]],

         [[ 1.8297, -0.1113, -0.0698, -1.4406, -1.8800],
          [-0.1329, -0.1117,  0.2653,  0.4799, -0.1228],
          [-0.6203,  0.2471, -0.4311, -0.1675,  1.0907],
          [ 0.0585, -0.1727,  0.3651,  0.4836, -0.4262]],

         [[-0.1344,  0.2650,  0.4419, -0.0680, -0.6477],
          [-0.6088, -0.4530,  0.6660, -0.9100,  1.5180],
          [ 0.3388,  0.3285, -0.3089,  0.5342, -1.0504],
          [ 0.2797,  0.0408, -0.4410,  0.3571, -0.2469]]]])
>>> g.shape
torch.Size([1, 3, 4, 5])
>>> c.shape
torch.Size([1, 3, 4, 5])
>>> a = torch.randn(1,2,3,4)
>>> b = torch.randn(1,5,2,4)
>>> c = torch.einsum('bhnm,bdhm->bdhn',a,b)
>>> a
tensor([[[[ 0.2333, -0.4853, -0.6971, -1.2038],
          [-0.0799,  1.1245,  0.5262,  0.5892],
          [-0.0836, -1.1434, -2.0070, -0.8349]],

         [[ 0.1631,  0.8936,  0.9199,  1.1889],
          [-0.4036, -0.1402, -1.1928, -1.4278],
          [-1.6179, -0.2395,  1.2021, -0.3717]]]])
>>> b
tensor([[[[ 1.9957, -1.4477, -0.5855,  0.5520],
          [ 0.2879,  0.5740,  0.4830, -1.0993]],

         [[ 0.7745, -0.0874, -1.2396, -0.4058],
          [-0.2885, -0.2290, -0.6926, -0.8500]],

         [[-1.1198,  0.2451,  0.6506,  1.1265],
          [ 0.4085, -1.4830,  0.2969, -0.9224]],

         [[-0.0554,  0.8500,  0.1219,  0.9584],
          [ 0.2657,  1.1979, -0.2655, -0.0734]],

         [[-0.8810,  0.8404, -0.0308,  1.1748],
          [-0.6998,  1.8537, -1.4930, -0.8909]]]])
>>> c
tensor([[[[ 0.9118, -1.7703,  2.2027],
          [-0.3028,  0.7969,  0.3859]],

         [[ 1.5757, -1.0516,  2.8619],
          [-1.8994,  2.1884,  0.0050]],

         [[-2.1897,  1.3712, -2.4329],
          [-2.0823,  1.0060,  0.3942]],

         [[-1.6641,  1.5891, -2.0121],
          [ 0.7823,  0.1464, -1.0086]],

         [[-2.0062,  1.6914, -1.8064],
          [-0.8902,  3.0755, -0.7754]]]])
>>> d = a.unsqueeze(1)
>>> e = b.unsqueeze(3)
>>> f = d*e
>>> g = torch.sum(f, dim=-1)
>>> g
tensor([[[[ 0.9118, -1.7703,  2.2027],
          [-0.3028,  0.7969,  0.3859]],

         [[ 1.5757, -1.0516,  2.8619],
          [-1.8994,  2.1884,  0.0050]],

         [[-2.1897,  1.3712, -2.4329],
          [-2.0823,  1.0060,  0.3942]],

         [[-1.6641,  1.5891, -2.0121],
          [ 0.7823,  0.1464, -1.0086]],

         [[-2.0062,  1.6914, -1.8064],
          [-0.8902,  3.0755, -0.7754]]]])
>>> 

  • 6
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值