类似PyTorch中的Scatter算子的实现

本文探讨了PyTorch中的Scatter算子与MindSpore.ops中的ScatterNd算子的区别,并提供了使用MindSpore实现PyTorch scatter功能的转换方法。在转换过程中遇到的问题是MindSpore缺少类似numpy argwhere的功能,导致效率不高。文章最后提到了MindSpore计划在1.8.1版本支持ScatterElement算子。

【问题描述】

PyTorch中的Scatter算子与MindSpore.ops中的ScatterNd算子并不能一一对应

1、 在PyTorch中, index矩阵中的位置和具体值形成实际的index, 然后将src中的值依据实际的index来写到self中

2、 在MindSpore中, ScatterNd需要的就是实际的index (二维矩阵), 即:

self[index[i, 0]][index[i, 1]]index[i, 2] = src[i]

【转化方法】

如果用MindSpore的方法实现PyTorch的scatter算子,主要是对index进行相互转换. 目前没有太高效的方法,个人的实现方法如下: 

def broadcast(src: ms.Tensor, axis:int):
    src = src.asnumpy()
    ix = np.argwhere(src == src)
    src = src.reshape(-1)
    ix[:, axis] = src
    return ms.Tensor(ix)
def scatter_(src: ms.Tensor, index: ms.Tensor, out: ms.Tensor, axis: int=-1):
    index = broadcast(index, axis)
    op = ops.TensorScatterUpdate()
    return op(out, index, src.reshape(-1))

【问题现象】

目前在broadcast方法中,numpy接口并没有提供类似原生接口中的argwhere方法, 并且where方法与原生numpy中的方法也并不一样,所以在这里只能从ms.Tensor转换成np.Arrray进行操作.

请问如何才能高效实现Scatter算子的转换, 或者怎样用Mindspore实现numpy中类似where或argwhere的方法?

对标MindSpore 的 ScatterElement算子。 计划1.8.1版本支持

 

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值