深度学习模型移植-替换gather算子

该文章介绍了一种在不支持复杂算子如gather的情况下,利用torch进行矩阵计算和简单操作来替代gather的方法。提供的torch_max_replace_gather函数展示了如何通过指数运算和张量操作实现类似功能,适用于边缘计算场景。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

很多边缘计算芯片不支持复杂的算子,比如gather算子。

这就需要使用简单的矩阵计算以及一些简单的torch操作来进行替换。

下面分享一些我在工作中用到的,能够替换gather的函数:

def torch_max_replace_gather(att_weights_prob, dim, ind_k):
    b,c,d,h,w = att_weights_prob.shape
    b,c,d1,h,w = ind_k.shape
    att_weights_prob = att_weights_prob.transpose(4,dim).reshape(b*c*h*w,d).unsqueeze(1).expand(-1,d1,-1)
    att_topk1 = ind_k.transpose(4,dim).reshape(b*c*h*w,d1).unsqueeze(2).expand(-1,-1,d)
    index_array = torch.arange(d).unsqueeze(0).unsqueeze(0).repeat(b*c*h*w,d1,1).to(att_weights_prob.device)
    att_topk1 = index_array - att_topk1
    att_topk1 =att_topk1*att_topk1
    att_topk1[att_topk1>0]=1
    att_topk1 = 1 - att_topk1
    min_val= torch.min(att_weights_prob, dim=dim)[0] - 0.1
    att_weights_prob = att_weights_prob - min_val.unsqueeze(2)
    att_weights_prob = att_weights_prob*att_topk1
    max_val = torch.max(att_weights_prob, dim=dim)[0]
    att_topk1 = (max_val + min_val).reshape(b,c,w,h,d1).transpose(4,2).transpose(3,4).transpose(3,4)
    return att_topk1

评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值