Mamba: LLVM ERROR: Cannot select: intrinsic %llvm.nvvm.shfl.sync.bfly.i32

Some of Mamba's methods are implemented with Triton, which does not support GPUs with compute capability below 7.0, hence the error. You can look up your GPU's compute capability here: https://developer.nvidia.com/cuda-gpus

There are two workarounds. The first is to buy a GPU whose compute capability (see the link above) is at least 7.0. The second is to replace the Triton-based methods with native PyTorch code. Note: the replacement only guarantees that the program runs; it does not guarantee correct results. In fact, the results of my own tests after the swap were wrong. The modified code was generated directly by GPT.
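Before editing anything, you can confirm whether your card is affected. A minimal check, assuming PyTorch was installed with CUDA support:

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)  # compute capability of GPU 0
    print(f"Compute capability: {major}.{minor}")
    if major < 7:
        print("Below 7.0: the Triton kernels in mamba_ssm will likely fail on this GPU.")
else:
    print("No CUDA device is visible to PyTorch.")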

The first place that needs changing is in site-packages/mamba_ssm/ops/triton/layernorm.py:

@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
#@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
#@triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    RESIDUAL_OUT,  # pointer to the residual
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)



with torch.cuda.device(x.device.index):
    _layer_norm_fwd_1pass_kernel[(M,)](
        x,
        y,
        weight,
        bias,
        residual,
        residual_out,
        mean,
        rstd,
        x.stride(0),
        y.stride(0),
        residual.stride(0) if residual is not None else 0,
        residual_out.stride(0) if residual_out is not None else 0,
        N,
        eps,
        is_rms_norm,
        BLOCK_N,
        residual is not None,
        residual_out is not None,
        bias is not None,
    )

Replace it with:

def _layer_norm_fwd_1pass_kernel(
    X,  # input tensor
    Y,  # output tensor (unused; the result is returned instead)
    W,  # weights tensor
    B,  # biases tensor
    RESIDUAL,  # residual tensor
    RESIDUAL_OUT,  # residual output tensor
    Mean,  # mean tensor (unused; the mean is returned instead)
    Rstd,  # rstd tensor (unused; the rstd is returned instead)
    stride_x_row,  # the stride arguments are unused here, kept so the call site matches
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    IS_RMS_NORM,
    BLOCK_N,
    HAS_RESIDUAL,
    STORE_RESIDUAL_OUT,
    HAS_BIAS,
):
    out_dtype = X.dtype  # remember the input dtype so the output matches it
    X = X.float()  # compute in float32, as the Triton kernel does
    if HAS_RESIDUAL:
        X = X + RESIDUAL.float()
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT = X.clone()

    if not IS_RMS_NORM:
        mean = X.mean(dim=1, keepdim=True)
        var = (X - mean).pow(2).mean(dim=1, keepdim=True)
    else:
        mean = None  # RMS norm does not subtract the mean
        var = X.pow(2).mean(dim=1, keepdim=True)

    rstd = 1.0 / torch.sqrt(var + eps)
    x_hat = (X - mean) * rstd if not IS_RMS_NORM else X * rstd

    Y = x_hat * W.float()
    if HAS_BIAS:
        Y = Y + B.float()
    Y = Y.to(out_dtype)

    # Return per-row statistics with the shape the Triton kernel stores, i.e. (M,)
    mean = mean.squeeze(1) if mean is not None else None
    rstd = rstd.squeeze(1)
    return Y, mean, rstd, RESIDUAL_OUT if STORE_RESIDUAL_OUT else None



with torch.cuda.device(x.device.index):
    y, mean, rstd, residual_out = _layer_norm_fwd_1pass_kernel(
        x,
        y,
        weight,
        bias,
        residual,
        residual_out,
        mean,
        rstd,
        x.stride(0),
        y.stride(0),
        residual.stride(0) if residual is not None else 0,
        residual_out.stride(0) if residual_out is not None else 0,
        N,
        eps,
        is_rms_norm,
        BLOCK_N,
        residual is not None,
        residual_out is not None,
        bias is not None,
    )
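Since correctness is not guaranteed, it is worth sanity-checking the PyTorch fallback against torch.nn.functional.layer_norm before relying on it. A minimal sketch on CPU tensors; the unused pointer and stride arguments are filled with None and 0:

import torch
import torch.nn.functional as F

M, N = 4, 64
x = torch.randn(M, N)
w = torch.randn(N)
b = torch.randn(N)

y, mean, rstd, _ = _layer_norm_fwd_1pass_kernel(
    x, None, w, b, None, None, None, None,  # Y/Mean/Rstd/residual slots are unused
    0, 0, 0, 0,                             # stride arguments are unused
    N, 1e-5,
    False,  # IS_RMS_NORM
    N,      # BLOCK_N (unused by the PyTorch version)
    False,  # HAS_RESIDUAL
    False,  # STORE_RESIDUAL_OUT
    True,   # HAS_BIAS
)
ref = F.layer_norm(x, (N,), w, b, eps=1e-5)
print(torch.allclose(y, ref, atol=1e-5))  # should print True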

The second place is in site-packages/mamba_ssm/ops/triton/selective_state_update.py:

@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
    # Matrix dimensions
    batch, dim, dstate,
    # Strides
    stride_state_batch, stride_state_dim, stride_state_dstate,
    stride_x_batch, stride_x_dim,
    stride_dt_batch, stride_dt_dim,
    stride_dt_bias_dim,
    stride_A_dim, stride_A_dstate,
    stride_B_batch, stride_B_dstate,
    stride_C_batch, stride_C_dstate,
    stride_D_dim,
    stride_z_batch, stride_z_dim,
    stride_out_batch, stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    state_ptr += pid_b * stride_state_batch
    x_ptr += pid_b * stride_x_batch
    dt_ptr += pid_b * stride_dt_batch
    B_ptr += pid_b * stride_B_batch
    C_ptr += pid_b * stride_C_batch
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch
    out_ptr += pid_b * stride_out_batch

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_DT_BIAS:
        dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if DT_SOFTPLUS:
        dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
    A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
    dA = tl.exp(A * dt[:, None])
    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    dB = B[None, :] * dt[:, None]
    state = state * dA + dB * x[:, None]
    tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)



with torch.cuda.device(x.device.index):
    _selective_scan_update_kernel[grid](
        state, x, dt, dt_bias, A, B, C, D, z, out,
        batch, dim, dstate,
        state.stride(0), state.stride(1), state.stride(2),
        x.stride(0), x.stride(1),
        dt.stride(0), dt.stride(1),
        dt_bias.stride(0) if dt_bias is not None else 0,
        A.stride(0), A.stride(1),
        B.stride(0), B.stride(1),
        C.stride(0), C.stride(1),
        D.stride(0) if D is not None else 0,
        z_strides[0], z_strides[1],
        out.stride(0), out.stride(1),
        dt_softplus,
        BLOCK_SIZE_M,
        # num_warps=num_warps,
    )

Replace it with:

def selective_scan_update_kernel(
    # The *_ptr arguments are plain PyTorch tensors here; the names are kept
    # so the call site needs minimal changes.
    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
    # Matrix dimensions
    batch, dim, dstate,
    # Strides (unused by the PyTorch version, kept for signature compatibility)
    stride_state_batch, stride_state_dim, stride_state_dstate,
    stride_x_batch, stride_x_dim,
    stride_dt_batch, stride_dt_dim,
    stride_dt_bias_dim,
    stride_A_dim, stride_A_dstate,
    stride_B_batch, stride_B_dstate,
    stride_C_batch, stride_C_dstate,
    stride_D_dim,
    stride_z_batch, stride_z_dim,
    stride_out_batch, stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS,
    BLOCK_SIZE_M,
):
    # Shapes: state (batch, dim, dstate), x/dt/z (batch, dim),
    # A (dim, dstate), B/C (batch, dstate), D/dt_bias (dim,).
    batch, dim, dstate = state_ptr.shape

    HAS_DT_BIAS = dt_bias_ptr is not None
    HAS_D = D_ptr is not None
    HAS_Z = z_ptr is not None

    dt = dt_ptr.float()
    if HAS_DT_BIAS:
        dt = dt + dt_bias_ptr
    if DT_SOFTPLUS:
        dt = torch.where(dt <= 20.0, torch.log1p(torch.exp(dt)), dt)

    # Discretize: dA and dB both broadcast to (batch, dim, dstate).
    dA = torch.exp(A_ptr * dt.unsqueeze(-1))
    dB = B_ptr.unsqueeze(1) * dt.unsqueeze(-1)

    # Update the state in place, as the Triton kernel does with tl.store.
    new_state = state_ptr * dA + dB * x_ptr.unsqueeze(-1)
    state_ptr.copy_(new_state)

    out = torch.sum(new_state * C_ptr.unsqueeze(1), dim=-1)
    if HAS_D:
        out = out + x_ptr * D_ptr
    if HAS_Z:
        out = out * z_ptr * torch.sigmoid(z_ptr)
    return out



with torch.cuda.device(x.device.index):
    out = selective_scan_update_kernel(
        state, x, dt, dt_bias, A, B, C, D, z, out,
        batch, dim, dstate,
        state.stride(0), state.stride(1), state.stride(2),
        x.stride(0), x.stride(1),
        dt.stride(0), dt.stride(1),
        dt_bias.stride(0) if dt_bias is not None else 0,
        A.stride(0), A.stride(1),
        B.stride(0), B.stride(1),
        C.stride(0), C.stride(1),
        D.stride(0) if D is not None else 0,
        z_strides[0], z_strides[1],
        out.stride(0), out.stride(1),
        dt_softplus,
        BLOCK_SIZE_M,
        # num_warps=num_warps,
    )
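As a quick smoke test of this fallback (shapes and the in-place state update only, not numerical agreement with the Triton kernel), here is a minimal sketch with random CPU tensors in the shapes mamba_ssm passes in: state (batch, dim, dstate), x/dt/z (batch, dim), A (dim, dstate), B/C (batch, dstate), D/dt_bias (dim,).

import torch

batch, dim, dstate = 2, 8, 4
state = torch.randn(batch, dim, dstate)
x, dt, z = (torch.randn(batch, dim) for _ in range(3))
A = torch.randn(dim, dstate)
B, C = torch.randn(batch, dstate), torch.randn(batch, dstate)
D, dt_bias = torch.randn(dim), torch.randn(dim)

state_before = state.clone()
out = selective_scan_update_kernel(
    state, x, dt, dt_bias, A, B, C, D, z, None,  # out_ptr is unused; result is returned
    batch, dim, dstate,
    *(0,) * 19,  # the 19 stride arguments are unused by the PyTorch version
    True,        # DT_SOFTPLUS
    8,           # BLOCK_SIZE_M (unused)
)
print(out.shape)                             # torch.Size([2, 8])
print(not torch.equal(state, state_before))  # True: the state was updated in place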

If you run into other problems, refer to this CSDN post: Mamba 环境安装踩坑问题汇总及解决方法 (https://blog.csdn.net/yyywxk/article/details/136071016).
