解决RuntimeError: “triu_tril_cuda_template“ not implemented for ‘BFloat16‘

注:该方法仅在qwen2-vl实验过,不知道是否通用,仅供参考,修改源文件前建议备份,如果不起作用方便还原

首先先找到报错的地方,例如作者的位置:

打开这个报错的文件,定位到该行代码,然后使用pytorch重新实现一个torch.triu函数

def custom_triu(input_tensor, diagonal=0):
    """
    返回一个与input_tensor相同形状的新张量,其中包含input_tensor的上三角部分,
    其余部分填充为0。diagonal参数决定了上三角部分的定义位置。
    
    参数:
    input_tensor (torch.Tensor): 输入张量。
    diagonal (int, 可选): 对角线偏移。默认为0(主对角线)。
    
    返回:
    torch.Tensor: 包含上三角部分的新张量。
    """
    # 获取输入张量的行数和列数
    rows, cols = input_tensor.size()
    
    # 创建一个与input_tensor相同形状的掩码
    # 使用torch.arange来生成行和列的索引,然后进行广播比较
    row_indices = torch.arange(rows, device=input_tensor.device).view(-1, 1)
    col_indices = torch.arange(cols, device=input_tensor.device).view(1, -1)
    mask = row_indices + diagonal < col_indices
    
    # 使用掩码选择上三角部分
    result_tensor = torch.where(mask, torch.tensor(0, device=input_tensor.device, dtype=input_tensor.dtype), input_tensor)
    
    return result_tensor

替换出错的地方,将causal_mask = torch_triu(causal_mask, diagonal=1)修改为causal_mask = custom_triu(causal_mask, diagonal=1),保存文件即可。

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值