torch.einsum —— 爱因斯坦标记法下的张量相乘的理解

A t t n i j b h × V a l u e s j d b h = Z i d b h Attn_{ijbh} \times Values_{jdbh} = Z_{idbh} Attnijbh×Valuesjdbh=Zidbh (爱因斯坦标记法)

这里的 i , d , b , h i,d,b,h i,d,b,h于‘->’左右都出现了,被称作自由标,而 j j j 只在左边出现,被称为哑标,计算结果则是保留自由标,对哑标进行逐元素相乘并求和

展开为求和公式

Σ j A t t n i j b h × V a l u e s j d b h = Z i d b h \Sigma_{j} Attn_{ijbh} \times Values_{jdbh} = Z_{idbh} ΣjAttnijbh×Valuesjdbh=Zidbh

例如,对于单个head和单个batch的Attention Matrix,可以表示如下:

[ A t t n 11 b h A t t n 12 b h ⋯ A t t n 1 j b h A t t n 21 b h A t t n 22 b h ⋯ A t t n 2 j b h ⋮ ⋮ ⋱ ⋮ A t t n i 1 b h A t t n i 2 b h ⋯ A t t n i j b h ] \begin{bmatrix} Attn_{11bh} & Attn_{12bh} & \cdots & Attn_{1jbh} \\ Attn_{21bh} & Attn_{22bh} & \cdots & Attn_{2jbh} \\ \vdots & \vdots & \ddots & \vdots \\ Attn_{i1bh} & Attn_{i2bh} & \cdots & Attn_{ijbh} \\ \end{bmatrix} Attn11bhAttn21bhAttni1bhAttn12bhAttn22bhAttni2bhAttn1jbhAttn2jbhAttnijbh

相应的Value Matrix表示为

[ V a l u e s 1 d b h V a l u e s 2 d b h ⋯ V a l u e s j d b h ] \begin{bmatrix} Values_{1dbh} \\ Values_{2dbh} \\ \cdots \\ Values_{jdbh} \\ \end{bmatrix} Values1dbhValues2dbhValuesjdbh

我们可以得到Z

[ Z 11 b h Z 12 b h ⋯ Z 1 d b h ⋮ ⋮ ⋱ ⋮ Z i 1 b h Z i 2 b h ⋯ Z i d b h ] = [ A t t n 11 b h A t t n 12 b h ⋯ A t t n 1 j b h A t t n 21 b h A t t n 22 b h ⋯ A t t n 2 j b h ⋮ ⋮ ⋱ ⋮ A t t n i 1 b h A t t n i 2 b h ⋯ A t t n i j b h ] ⋅ [ V a l u e s 1 d b h V a l u e s 2 d b h ⋯ V a l u e s j d b h ] \begin{bmatrix} Z_{11bh} & Z_{12bh} & \cdots & Z_{1dbh} \\ \vdots & \vdots & \ddots & \vdots \\ Z_{i1bh} & Z_{i2bh} & \cdots & Z_{idbh} \\ \end{bmatrix} = \begin{bmatrix} Attn_{11bh} & Attn_{12bh} & \cdots & Attn_{1jbh} \\ Attn_{21bh} & Attn_{22bh} & \cdots & Attn_{2jbh} \\ \vdots & \vdots & \ddots & \vdots \\ Attn_{i1bh} & Attn_{i2bh} & \cdots & Attn_{ijbh} \\ \end{bmatrix} \cdot \begin{bmatrix} Values_{1dbh} \\ Values_{2dbh} \\ \cdots \\ Values_{jdbh} \\ \end{bmatrix} Z11bhZi1bhZ12bhZi2bhZ1dbhZidbh = Attn11bhAttn21bhAttni1bhAttn12bhAttn22bhAttni2bhAttn1jbhAttn2jbhAttnijbh Values1dbhValues2dbhValuesjdbh

其中

Z i d b h ∣ i = 1 , d = 1 = [ A t t n i j b h ∣ i = 1 , j = 1 A t t n i j b h ∣ i = 1 , j = 2 ⋯ A t t n 1 j b h ] ⋅ [ V a l u e s j d b h ∣ j = 1 V a l u e s j d b h ∣ j = 2 ⋯ V a l u e s j d b h ] Z_{idbh} \vert _{i=1,d=1} = \begin{bmatrix} Attn_{ijbh} \vert _{i=1,j=1} & Attn_{ijbh} \vert _{i=1,j=2} & \cdots & Attn_{1jbh} \end{bmatrix} \cdot \begin{bmatrix} Values_{jdbh}\vert _{j=1} \\ Values_{jdbh}\vert _{j=2} \\ \cdots \\ Values_{jdbh} \\ \end{bmatrix} Zidbhi=1,d=1=[Attnijbhi=1,j=1Attnijbhi=1,j=2Attn1jbh] Valuesjdbhj=1Valuesjdbhj=2Valuesjdbh

从中我们可以看出所谓的对哑标 j j j 进行逐元素相乘并求和。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值