一个torch.baddbmm算子实现的细节

torch.baddbmm算子定义如下

def baddbmm(input: Tensor, batch1: Tensor, batch2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None)

计算结果大致为:
input * beta + batch1*batch2*alpha

在一些算法框架,如Megatron-DeepSpeed下,对于input会用empty定义出来,这就导致会进行一个empty*beta的计算过程,此时需要注意,如果beta是0的话,inptu*beta是要跳过的,不然的话如果empty过程中引入了inf值,会导致计算结果中存在NAN。

        # pre-allocating result tensor: [b * np, sq, sk]
        if alibi is None:
            matmul_result = torch.zeros(
                output_size[0]*output_size[1],
                output_size[2],
                output_size[3],
                dtype=query_layer.dtype,
                device=get_current_device())
        else:
            matmul_result = alibi[:output_size[0]*output_size[1], :, :output_size[3]]
            alibi_emb = alibi[:output_size[0]*output_size[1], :, :output_size[3]]

        # Rotary embeddings
        if self.position_embedding_type == PositionEmbeddingType.rotary:
            apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb

            seq_len = key_layer.shape[0]
            offset = 0
            if layer_past is not None and layer_past.numel() > 0:
                offset = layer_past[0].shape[0]
                seq_len += offset

            if self.glm:
                cos, sin = self.rotary_emb(value_layer, seq_len=seq_len // 2)
                rot_dim = query_layer.shape[-1] // 2
                query_layer_rope, query_layer_pass = query_layer[..., :rot_dim], query_layer[..., rot_dim:]
                key_layer_rope, key_layer_pass = key_layer[..., :rot_dim], key_layer[..., rot_dim:]
                query_layer_rope, key_layer_rope = apply_rotary_fn(query_layer_rope, key_layer_rope, cos, sin, offset=offset)
                query_layer = torch.cat((query_layer_rope, query_layer_pass), dim=-1)
                key_layer = torch.cat((key_layer_rope, key_layer_pass), dim=-1)
            else:
                cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
                query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)

        # Raw attention scores. [b * np, sq, sk]
        beta = 0.0
        if alibi is not None:
            beta = 1.0
        if alibi is None:
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),   # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=beta, alpha=(1.0/self.norm_factor))
        else:
            matmul_result = torch.bmm(query_layer.transpose(0, 1),   # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            ) * (1.0/self.norm_factor) + alibi_emb

在torch的CPU实现中,也是规避了0带来的NAN的异常值引入。

            if (is_bmm) {
              r2[j] = acc_value;
            } else {
              // For beta == 0, the r's value will be ignored, especially for nan value.
              if (beta == opmath_t{0}) {
                r2[j] = alpha * acc_value;
              } else {
                r2[j] = static_cast<opmath_t>(r2[j]) * beta + alpha * acc_value;
              }
            }

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

碳纤维石头君

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值