项目场景:
einsum是个宝,写网络时可以显示地指定tensor reshape的方式,也可以直接指定纬度进行矩阵计算。然而,torch.einops和einops.einsum (或者import einsum)的命令方式不一样,会导致意外的没有按照预设的方式进行矩阵计算的情况。
问题描述
假设两个矩阵的形状为BPD
和bpd
,其中B=b,P=p,D=d
,就是两个一样的transformer出来的tokens,用来做CLIP。我想让它们在D上做乘法,B上做加法,P上维持形状,即输出BPp
的相似度矩阵,有两种实现方式:
import torch
from einops import einsum
a = torch.ones([32,256,768]) # B P D
b = torch.ones([32,256,768]) # b p d
c_0 = torch.einsum('BPD, BpD -> BPp',a,b) # 32, 256, 256
c_1 = einsum(a,b,'B P D, B p D -> B P p') # 32, 256, 256
这里我通过将两个tensor的第三维都用D表示,来指定在D上做乘法,但如果分别用Dd表示,torch.einsum也能做乘法,但输出的结果和上面的会不一样:
c_2 = torch.einsum('BPD, Bpd -> BPp',a,b) # 32, 256, 256
c_3 = einsum(a,b,'B P D, B p d -> B P p') # 32, 256, 256
# >>> c_0[0,0,0]
# tensor(768.)
# >>> c_1[0,0,0]
# tensor(768.)
# >>> c_2[0,0,0]
# tensor(589824.)
# >>> c_3[0,0,0]
# tensor(589824.)
原因分析:
这是因为,两个einops,都是相同字符串做乘法,不同字符串做拼接或者加法。在第一个例子中,第三维是同一个D,所以是两个768的向量做了哈达马积(按位相乘),而第二种情况,是将两个768的全1向量,拓展成矩阵算了哈达马积。或者换一种说法,一个是[1,768]@[768,1]
,一个是[768,1]@[1,768]
解决方案:
使用einsum时,一定注意字符串内大小写的含义,或者,只用einsum或einops.rearrange进行reshape,然后老老实实用torch进行矩阵乘法。