如何快速看出矩阵乘法的时间复杂度

以 Attention Score 的计算为例
A t t n ( K , Q , V ) = S o f t m a x ( Q ⋅ K T / d ) ⋅ V Attn(K,Q,V) = Softmax(Q\cdot K^T/\sqrt{d})\cdot V Attn(K,Q,V)=Softmax(QKT/d )V
咱姑且把 Softmax 和 Softmax里面的除以 d \sqrt{d} d 去掉(其运算时间复杂度小),表示为
A t t n ( K , Q , V ) = Q ⋅ K T ⋅ V Attn(K,Q,V) = Q\cdot K^T\cdot V Attn(K,Q,V)=QKTV
其中, Q , K , V ∈ R N × d Q,K,V \in \mathbb{R}^{N\times d} Q,K,VRN×d N N N 是token的数量, d d d 是每个token的维度,一般认为 N N N> d d d

Q ⋅ K T Q\cdot K^T QKT 从矩阵乘法上看维度变换是 N × d × d × N N\times d \times d \times N N×d×d×N,得到的矩阵维度是 N × N N\times N N×N,即得到的矩阵有 N 2 N^2 N2 个元素,每个元素需要经过d个元素相乘再相加得到(加权求和),所以 Q ⋅ K T Q\cdot K^T QKT 计算的时间复杂度为 O ( N 2 d ) O(N^2d) O(N2d)

  • 总结一个快速得出结论的方法

如果你不想鸟我上面写的,你只需要按照这个规则来看

比如两个矩阵 M ⋅ N , M ∈ R m × n , M ∈ R n × k M\cdot N, M\in\mathbb{R}^{m\times n}, M\in\mathbb{R}^{n\times k} MN,MRm×n,MRn×k

按照维度表示为 m × n × n × k m\times n \times n \times k m×n×n×k只需要把中间的两个 n n n 删掉一个即可表示时间复杂度,为 O ( m × n × k ) O(m\times n\times k) O(m×n×k)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值