ChatGLMv2 RoPE的代码实现

前言

参考链接:https://zhuanlan.zhihu.com/p/645263524
一开始对tensor的 reshape, 片选操作不熟, 还以为v2没有做rotate动作, 请教了之后才算搞懂了。

代码注释

@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    # Tag:q, k: [sq, b, np, hn], hn=128
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    # rot_dim =64
    rot_dim = rope_cache.shape[-2] * 2
    # Tag: x-->(sq, b, 64, 2)
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]

    # (sq,b, 32, 2)
    # Tag:总体上就是用的reshape+片选 实现rotate 交换动作;
    # 最后一维是2, 举例如下:[q1, q2],[q3, q4]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    # (sq,b, 1, 32, 2); 同样最后一维是2, 举例如下:[cos1, sin1],[cos2, sin2], ...
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    # rope_cache[..., 0]; shape-->(sq,b, 32)
    # xshaped[..., 0]-->[q1, q3,...]; 利用片选分离出单数, 双数的q
    x_out2 = torch.stack(
        [
            # [q1, q3, ] *[cos1, cos2, cos3] - [q2, q4, ] *[sin1, sin2, sin3]
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            # [q2, q4, ] *[cos1, cos2, cos3] - [q1, q3, ] *[sin1, sin2, sin3]
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)

最终实现下图公式的效果
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值