【问题描述】
PyTorch中的Scatter算子与MindSpore.ops中的ScatterNd算子并不能一一对应
1、 在PyTorch中, index矩阵中的位置和具体值形成实际的index, 然后将src中的值依据实际的index来写到self中

2、 在MindSpore中, ScatterNd需要的就是实际的index (二维矩阵), 即:
self[index[i, 0]][index[i, 1]]index[i, 2] = src[i]
【转化方法】
如果用MindSpore的方法实现PyTorch的scatter算子,主要是对index进行相互转换. 目前没有太高效的方法,个人的实现方法如下:
def broadcast(src: ms.Tensor, axis:int):
src = src.asnumpy()
ix = np.argwhere(src == src)
src = src.reshape(-1)
ix[:, axis] = src
return ms.Tensor(ix)
def scatter_(src: ms.Tensor, index: ms.Tensor, out: ms.Tensor, axis: int=-1):
index = broadcast(index, axis)
op = ops.TensorScatterUpdate()
return op(out, index, src.reshape(-1))
【问题现象】
目前在broadcast方法中,numpy接口并没有提供类似原生接口中的argwhere方法, 并且where方法与原生numpy中的方法也并不一样,所以在这里只能从ms.Tensor转换成np.Arrray进行操作.
请问如何才能高效实现Scatter算子的转换, 或者怎样用Mindspore实现numpy中类似where或argwhere的方法?
对标MindSpore 的 ScatterElement算子。 计划1.8.1版本支持
本文探讨了PyTorch中的Scatter算子与MindSpore.ops中的ScatterNd算子的区别,并提供了使用MindSpore实现PyTorch scatter功能的转换方法。在转换过程中遇到的问题是MindSpore缺少类似numpy argwhere的功能,导致效率不高。文章最后提到了MindSpore计划在1.8.1版本支持ScatterElement算子。

2469

被折叠的 条评论
为什么被折叠?



