torch.baddbmm算子定义如下
def baddbmm(input: Tensor, batch1: Tensor, batch2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None)
计算结果大致为:
beta * input + alpha * (batch1 @ batch2)
在一些算法框架(如 Megatron-DeepSpeed)下,input 会用 torch.empty 定义出来,这就会引入一次 empty * beta 的计算。此时需要注意:如果 beta 为 0,input * beta 这一步必须被跳过;否则若 empty 分配到的未初始化内存中恰好含有 inf 值,由于 0 * inf = NaN,会导致计算结果中出现 NaN。
# Pre-allocating result tensor: [b * np, sq, sk]
if alibi is None:
    # The later baddbmm call uses beta=0, so this buffer's contents are
    # never read; zeros (rather than empty) just guarantees it is finite.
    matmul_result = torch.zeros(
        output_size[0] * output_size[1],
        output_size[2],
        output_size[3],
        dtype=query_layer.dtype,
        device=get_current_device())
else:
    # Slice alibi once (the original duplicated this identical slice into
    # two names); both names still refer to the same [b*np, :, sk] view.
    alibi_emb = alibi[:output_size[0] * output_size[1], :, :output_size[3]]
    matmul_result = alibi_emb
# Rotary embeddings: rotate query/key by position before computing scores.
if self.position_embedding_type == PositionEmbeddingType.rotary:
    # bf16 uses the pure-torch rotary implementation; otherwise the default one.
    apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
    # NOTE(review): assumes key_layer layout is [sq, b, np, hn] (seq-first) — confirm against caller.
    seq_len = key_layer.shape[0]
    offset = 0
    # Incremental decoding: continue rotary positions after the cached prefix
    # so new tokens get the correct absolute position.
    if layer_past is not None and layer_past.numel() > 0:
        offset = layer_past[0].shape[0]
        seq_len += offset
    if self.glm:
        # GLM variant: rotate only the first half of the head dimension and
        # pass the second half through unchanged, then re-concatenate.
        cos, sin = self.rotary_emb(value_layer, seq_len=seq_len // 2)
        rot_dim = query_layer.shape[-1] // 2
        query_layer_rope, query_layer_pass = query_layer[..., :rot_dim], query_layer[..., rot_dim:]
        key_layer_rope, key_layer_pass = key_layer[..., :rot_dim], key_layer[..., rot_dim:]
        query_layer_rope, key_layer_rope = apply_rotary_fn(query_layer_rope, key_layer_rope, cos, sin, offset=offset)
        query_layer = torch.cat((query_layer_rope, query_layer_pass), dim=-1)
        key_layer = torch.cat((key_layer_rope, key_layer_pass), dim=-1)
    else:
        # Standard path: rotate the full head dimension of both q and k.
        cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
        query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
# Raw attention scores. [b * np, sq, sk]
# beta selects whether baddbmm's `input` term contributes. It is only read
# on the alibi-is-None path below, where it is 0.0: with beta=0 baddbmm
# skips `input * beta` entirely, so garbage (inf) in a torch.empty-style
# matmul_result cannot become NaN in the scores.
beta = 0.0 if alibi is None else 1.0
if alibi is None:
    matmul_result = torch.baddbmm(
        matmul_result,
        query_layer.transpose(0, 1),                # [b * np, sq, hn]
        key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
        beta=beta, alpha=(1.0 / self.norm_factor))
else:
    # Alibi path: plain bmm, then scale and add the bias explicitly.
    scores = torch.bmm(query_layer.transpose(0, 1),                # [b * np, sq, hn]
                       key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                       )
    matmul_result = scores * (1.0 / self.norm_factor) + alibi_emb
在 torch 的 CPU 实现中,也同样通过对 beta == 0 的特判直接跳过对 r 原有值的读取,规避了 0 * inf 产生 NaN 异常值的问题。
if (is_bmm) {
// Plain bmm has no input/beta term: the result is just the accumulated product.
r2[j] = acc_value;
} else {
// For beta == 0, r's prior value must be ignored entirely (not multiplied):
// the destination may hold uninitialized data, and 0 * inf would otherwise
// poison the result with NaN.
if (beta == opmath_t{0}) {
r2[j] = alpha * acc_value;
} else {
// General case: out = beta * input + alpha * (batch1 @ batch2).
r2[j] = static_cast<opmath_t>(r2[j]) * beta + alpha * acc_value;
}
}