mamba 中部分方法使用了 Triton,而计算能力(compute capability)低于 7.0 的 GPU 不支持 Triton,所以在这类显卡上运行会报错。各型号 GPU 的计算能力可查阅:https://developer.nvidia.com/cuda-gpus
解决方法一是参照上述链接,换用计算能力不低于 7.0 的显卡;方法二是将使用了 Triton 的方法替换为原生 PyTorch 代码。注意:替换后只能保证程序能够运行,并不能保证结果正确——事实上,我替换后测试得到的结果是错误的。修改后的代码由 GPT 直接生成。
第一个需要更改的地方在 site-packages/mamba_ssm/ops/triton/layernorm.py 中的
# Triton forward kernel for LayerNorm / RMSNorm with an optional fused residual
# add. Each program instance handles exactly one row (row = tl.program_id(0)),
# loading up to BLOCK_N columns; it assumes BLOCK_N >= N, since columns at or
# beyond BLOCK_N are never loaded.
# Autotune across these warp counts, keyed on the listed arguments.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# Disabled heuristics: HAS_BIAS / HAS_RESIDUAL are passed explicitly by the launch instead.
#@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
#@triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    RESIDUAL_OUT,  # pointer to the residual output (x + residual, before normalization)
    Mean,  # pointer to the mean (one value per row; written only when not IS_RMS_NORM)
    Rstd,  # pointer to the 1/std (one value per row)
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    # Masked-out lanes read 0.0 so they do not contribute to the sums below.
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        # Stored value is the (possibly residual-added) input, before normalization.
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        # Re-zero padded lanes: after subtracting the mean they would be -mean.
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        # RMSNorm: no centering, "var" is the mean of squares.
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
# Launch one program per row: grid is (M,), and the kernel uses
# tl.program_id(0) as the row index. M, N, BLOCK_N, eps, is_rms_norm and the
# output buffers come from the enclosing function (not shown here; M is
# presumably x.shape[0] — confirm against the caller).
with torch.cuda.device(x.device.index):
    _layer_norm_fwd_1pass_kernel[(M,)](
        x,
        y,
        weight,
        bias,
        residual,
        residual_out,
        mean,
        rstd,
        x.stride(0),
        y.stride(0),
        # Stride 0 is a placeholder for absent tensors; the kernel never
        # offsets RESIDUAL / RESIDUAL_OUT when the matching flag is False.
        residual.stride(0) if residual is not None else 0,
        residual_out.stride(0) if residual_out is not None else 0,
        N,
        eps,
        is_rms_norm,
        BLOCK_N,
        residual is not None,  # HAS_RESIDUAL
        residual_out is not None,  # STORE_RESIDUAL_OUT
        bias is not None,  # HAS_BIAS
    )
替换为
def layer_norm_fwd_1pass_kernel(
    X,  # input tensor; every row (last dim) is normalized independently
    Y,  # output buffer, written in place
    W,  # weights tensor
    B,  # biases tensor (used only when HAS_BIAS)
    RESIDUAL,  # residual tensor (used only when HAS_RESIDUAL)
    RESIDUAL_OUT,  # residual output buffer, written in place when STORE_RESIDUAL_OUT
    Mean,  # per-row mean buffer, written in place unless IS_RMS_NORM; may be None
    Rstd,  # per-row 1/std buffer, written in place
    stride_x_row,  # unused: kept only so the argument list matches the Triton kernel
    stride_y_row,  # unused
    stride_res_row,  # unused
    stride_res_out_row,  # unused
    N,  # number of columns in X (the row length is taken from X itself)
    eps,  # epsilon to avoid division by zero
    IS_RMS_NORM,
    BLOCK_N,  # unused
    HAS_RESIDUAL,
    STORE_RESIDUAL_OUT,
    HAS_BIAS,
):
    """Pure-PyTorch stand-in for the Triton ``_layer_norm_fwd_1pass_kernel``.

    Like the Triton kernel, results are written *in place* into the
    preallocated buffers ``Y``, ``Mean`` (LayerNorm only), ``Rstd`` and
    ``RESIDUAL_OUT``. All arithmetic is done in float32, and ``copy_`` casts
    the results to each buffer's own dtype.

    Returns:
        ``(Y, Mean, Rstd, RESIDUAL_OUT)`` when ``STORE_RESIDUAL_OUT`` is true,
        otherwise the untouched input ``X``.
    """
    x = X.float()
    if HAS_RESIDUAL:
        x = x + RESIDUAL.float()
    if STORE_RESIDUAL_OUT:
        # The residual stream after the add and before normalization.
        RESIDUAL_OUT.copy_(x)

    if IS_RMS_NORM:
        # RMSNorm: no centering, and no Mean output.
        centered = x
    else:
        mean = x.mean(dim=-1, keepdim=True)
        if Mean is not None:
            Mean.copy_(mean.squeeze(-1))
        centered = x - mean

    var = centered.pow(2).mean(dim=-1, keepdim=True)
    rstd = 1.0 / torch.sqrt(var + eps)
    if Rstd is not None:
        Rstd.copy_(rstd.squeeze(-1))

    y = centered * rstd * W.float()
    if HAS_BIAS:
        y = y + B.float()
    # Without this copy the caller's Y would keep whatever torch.empty left in it.
    Y.copy_(y)

    if STORE_RESIDUAL_OUT:
        return Y, Mean, Rstd, RESIDUAL_OUT
    # The Triton kernel never modifies X, so hand back the original input.
    return X


# The launch code refers to the Triton kernel's name; keep that name resolvable.
_layer_norm_fwd_1pass_kernel = layer_norm_fwd_1pass_kernel
# Call the PyTorch replacement (defined as layer_norm_fwd_1pass_kernel, without
# the Triton kernel's leading underscore). One call with the same argument list
# serves both cases; strides and BLOCK_N are forwarded only to match its
# positional signature.
with torch.cuda.device(x.device.index):
    ret = layer_norm_fwd_1pass_kernel(
        x,
        y,
        weight,
        bias,
        residual,
        residual_out,
        mean,
        rstd,
        x.stride(0),
        y.stride(0),
        residual.stride(0) if residual is not None else 0,
        residual_out.stride(0) if residual_out is not None else 0,
        N,
        eps,
        is_rms_norm,
        BLOCK_N,
        residual is not None,  # HAS_RESIDUAL
        residual_out is not None,  # STORE_RESIDUAL_OUT
        bias is not None,  # HAS_BIAS
    )
if residual_out is not None:
    # With a residual-output buffer the function returns (y, mean, rstd, residual_out).
    y, mean, rstd, residual_out = ret
# x is deliberately not rebound: the Triton launch this replaces never modified x.
第二个需要更改的地方在 site-packages/mamba_ssm/ops/triton/selective_state_update.py 中的
# Triton kernel: one recurrent state-space step. Program axis 0 tiles the `dim`
# axis in blocks of BLOCK_SIZE_M rows; program axis 1 selects the batch element.
# The flags below are derived from the launch arguments rather than passed
# explicitly: HAS_* is True when the corresponding pointer is not None, and
# BLOCK_SIZE_DSTATE is dstate rounded up to a power of two.
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
    # Matrix dimensions
    batch, dim, dstate,
    # Strides (their names give the layout: state is (batch, dim, dstate); x, dt,
    # z and out are (batch, dim); A is (dim, dstate); B and C are (batch, dstate);
    # dt_bias and D are (dim,))
    stride_state_batch, stride_state_dim, stride_state_dstate,
    stride_x_batch, stride_x_dim,
    stride_dt_batch, stride_dt_dim,
    stride_dt_bias_dim,
    stride_A_dim, stride_A_dstate,
    stride_B_batch, stride_B_dstate,
    stride_C_batch, stride_C_dstate,
    stride_D_dim,
    stride_z_batch, stride_z_dim,
    stride_out_batch, stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    # Advance every per-batch pointer to this program's batch element.
    state_ptr += pid_b * stride_state_batch
    x_ptr += pid_b * stride_x_batch
    dt_ptr += pid_b * stride_dt_batch
    B_ptr += pid_b * stride_B_batch
    C_ptr += pid_b * stride_C_batch
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch
    out_ptr += pid_b * stride_out_batch

    # offs_m: the dim rows of this tile; offs_n: all (padded) dstate columns.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    # state keeps its stored dtype; all other inputs are upcast to float32.
    state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_DT_BIAS:
        dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if DT_SOFTPLUS:
        # softplus, passed through unchanged above 20 to avoid exp overflow.
        dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
    A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
    dA = tl.exp(A * dt[:, None])
    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    # state <- state * exp(A*dt) + (B*dt) * x, written back to memory in place.
    dB = B[None, :] * dt[:, None]
    state = state * dA + dB * x[:, None]
    tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
    # out = <state, C> (+ D * x) (* z * sigmoid(z))
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)
# Launch the Triton kernel. grid, z_strides, BLOCK_SIZE_M, batch/dim/dstate and
# dt_softplus come from the enclosing function (not shown). HAS_DT_BIAS, HAS_D,
# HAS_Z and BLOCK_SIZE_DSTATE are not passed: the kernel's @triton.heuristics
# decorators derive them from these arguments.
with torch.cuda.device(x.device.index):
    _selective_scan_update_kernel[grid](
        state, x, dt, dt_bias, A, B, C, D, z, out,
        batch, dim, dstate,
        state.stride(0), state.stride(1), state.stride(2),
        x.stride(0), x.stride(1),
        dt.stride(0), dt.stride(1),
        # Stride 0 placeholders for absent optional tensors.
        dt_bias.stride(0) if dt_bias is not None else 0,
        A.stride(0), A.stride(1),
        B.stride(0), B.stride(1),
        C.stride(0), C.stride(1),
        D.stride(0) if D is not None else 0,
        z_strides[0], z_strides[1],
        out.stride(0), out.stride(1),
        dt_softplus,
        BLOCK_SIZE_M,
        # num_warps=num_warps,
    )
替换为
def selective_scan_update_kernel(
    # Tensors (the *_ptr names are kept from the Triton kernel's signature)
    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
    # Matrix dimensions (unused: broadcasting takes shapes from the tensors)
    batch, dim, dstate,
    # Strides (unused: kept only so the argument list matches the Triton kernel)
    stride_state_batch, stride_state_dim, stride_state_dstate,
    stride_x_batch, stride_x_dim,
    stride_dt_batch, stride_dt_dim,
    stride_dt_bias_dim,
    stride_A_dim, stride_A_dstate,
    stride_B_batch, stride_B_dstate,
    stride_C_batch, stride_C_dstate,
    stride_D_dim,
    stride_z_batch, stride_z_dim,
    stride_out_batch, stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS,
    BLOCK_SIZE_M,  # unused
):
    """Pure-PyTorch stand-in for the Triton ``_selective_scan_update_kernel``.

    Performs one recurrent step for every (batch, dim) pair::

        dt    = softplus(dt + dt_bias)              (bias / softplus optional)
        state = state * exp(A * dt) + B * dt * x    (stored back into state_ptr)
        out   = sum(state * C, -1) + D * x          (D optional)
        out   = out * (z * sigmoid(z))              (z optional)

    Layouts follow the Triton kernel's stride arguments: ``state`` is
    ``(batch, dim, dstate)``; ``x``, ``dt``, ``z`` and ``out`` are
    ``(batch, dim)``; ``A`` is ``(dim, dstate)``; ``B`` and ``C`` are
    ``(batch, dstate)``; ``dt_bias`` and ``D`` are ``(dim,)``.

    Arithmetic runs in float32, as in the kernel. ``state_ptr`` is updated in
    place, and the output is written into ``out_ptr`` when one is given.

    Returns:
        ``out_ptr`` holding the result, or a new tensor with ``x_ptr``'s dtype
        when ``out_ptr`` is None.
    """
    x = x_ptr.float()
    dt = dt_ptr.float()
    if dt_bias_ptr is not None:
        dt = dt + dt_bias_ptr.float()
    if DT_SOFTPLUS:
        # Same thresholded softplus as the kernel: identity above 20.
        dt = torch.where(dt <= 20.0, torch.log1p(torch.exp(dt)), dt)

    dt_col = dt.unsqueeze(-1)                        # (batch, dim, 1)
    dA = torch.exp(A_ptr.float() * dt_col)           # (batch, dim, dstate)
    # B is per batch element, so it needs an explicit dim axis before broadcasting.
    dB = B_ptr.float().unsqueeze(1) * dt_col         # (batch, dim, dstate)
    new_state = state_ptr.float() * dA + dB * x.unsqueeze(-1)
    # The kernel stores the new state back (tl.store); without this copy the
    # caller's state tensor would never change.
    state_ptr.copy_(new_state)

    out = (new_state * C_ptr.float().unsqueeze(1)).sum(dim=-1)  # (batch, dim)
    if D_ptr is not None:
        out = out + x * D_ptr.float()
    if z_ptr is not None:
        z = z_ptr.float()
        out = out * (z * torch.sigmoid(z))

    if out_ptr is None:
        return out.to(x_ptr.dtype)
    out_ptr.copy_(out)
    return out_ptr
# Call the PyTorch replacement by keyword so each value is visibly bound to its
# parameter. The replacement ignores the stride values and BLOCK_SIZE_M; they
# are supplied only because its signature still lists them.
with torch.cuda.device(x.device.index):
    out = selective_scan_update_kernel(
        state_ptr=state, x_ptr=x, dt_ptr=dt, dt_bias_ptr=dt_bias,
        A_ptr=A, B_ptr=B, C_ptr=C, D_ptr=D, z_ptr=z, out_ptr=out,
        batch=batch, dim=dim, dstate=dstate,
        stride_state_batch=state.stride(0),
        stride_state_dim=state.stride(1),
        stride_state_dstate=state.stride(2),
        stride_x_batch=x.stride(0), stride_x_dim=x.stride(1),
        stride_dt_batch=dt.stride(0), stride_dt_dim=dt.stride(1),
        stride_dt_bias_dim=0 if dt_bias is None else dt_bias.stride(0),
        stride_A_dim=A.stride(0), stride_A_dstate=A.stride(1),
        stride_B_batch=B.stride(0), stride_B_dstate=B.stride(1),
        stride_C_batch=C.stride(0), stride_C_dstate=C.stride(1),
        stride_D_dim=0 if D is None else D.stride(0),
        stride_z_batch=z_strides[0], stride_z_dim=z_strides[1],
        stride_out_batch=out.stride(0), stride_out_dim=out.stride(1),
        DT_SOFTPLUS=dt_softplus,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
    )