Problem description
ops.gather() vs. torch.gather()
The API mapping documentation does not describe the differences between these two functions, so I assumed they behave identically. During model migration, however, I found that ops.gather() produces results very different from torch.gather().
Here is a test case (mindspore=2.2.14, torch=1.13.1+cu116):
import numpy as np
cumwidths = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]])
bin_idx = np.array([[[0, 1, 2, 0], [2, 0, 1, 3], [3, 3, 0, 2]],
[[1, 2, 3, 0], [0, 1, 2, 3], [3, 0, 1, 2]]])
## test mindspore
# import mindspore
# cumwidths = mindspore.Tensor(cumwidths)
# bin_idx = mindspore.Tensor(bin_idx)
# tmp = cumwidths.gather(bin_idx, axis=-1)
# input_cumwidths = tmp[...,0]
# test torch
import torch
cumwidths = torch.tensor(cumwidths)
bin_idx = torch.tensor(bin_idx)
tmp = cumwidths.gather(-1, bin_idx)
input_cumwidths = tmp[...,0]
# output
print("cumwidths shape:", cumwidths.shape)
print("bin_idx shape:", bin_idx.shape)
print("tmp shape:", tmp.shape)
print("input_cumwidths shape:", input_cumwidths.shape)
print(cumwidths)
print(bin_idx)
print(tmp)
print(input_cumwidths)
I'll omit the full tensor printouts; here are the shapes from both runs:
# torch
cumwidths shape: torch.Size([2, 3, 4])
bin_idx shape: torch.Size([2, 3, 4])
tmp shape: torch.Size([2, 3, 4])
input_cumwidths shape: torch.Size([2, 3])
# mindspore
cumwidths shape: (2, 3, 4)
bin_idx shape: (2, 3, 4)
tmp shape: (2, 3, 2, 3, 4)
input_cumwidths shape: (2, 3, 2, 3)
For this problem, my current workaround is to reimplement a simplified torch.gather using numpy plus mindspore.Tensor.
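As a sketch of that workaround (assuming only the last-axis case the code above needs), torch.gather(dim=-1) semantics can be reproduced with numpy.take_along_axis; the resulting array can then be wrapped back into a mindspore.Tensor:

```python
import numpy as np

def gather_last_axis(data, index):
    # Emulates torch.gather(dim=-1): out[..., j] = data[..., index[..., j]]
    return np.take_along_axis(data, index, axis=-1)

cumwidths = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
                      [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]])
bin_idx = np.array([[[0, 1, 2, 0], [2, 0, 1, 3], [3, 3, 0, 2]],
                    [[1, 2, 3, 0], [0, 1, 2, 3], [3, 0, 1, 2]]])
tmp = gather_last_axis(cumwidths, bin_idx)
print(tmp.shape)  # (2, 3, 4), matching torch.gather(-1, bin_idx)
```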
Indexing problem
I'm not sure how to describe this problem, so let's go straight to the test code. The torch test code:
import numpy as np
import torch
input = torch.Tensor(np.random.randn(2, 1, 7, 11))
val = torch.Tensor(np.zeros((2, 1, 7)))
bound = 1
mask = (val >= -bound) & (val <= bound)
print(mask.shape)
print(input[mask,:].shape)
The MindSpore test code (identical to the above, with torch simply replaced by mindspore):
import numpy as np
import mindspore
input = mindspore.Tensor(np.random.randn(2, 1, 7, 11))
val = mindspore.Tensor(np.zeros((2, 1, 7)))
bound = 1
mask = (val >= -bound) & (val <= bound)
print(mask.shape)
print(input[mask,:].shape)
Comparing the outputs, the shapes are clearly different, and I'm not sure which step in the code causes this:
# torch
torch.Size([2, 1, 7])
torch.Size([14, 11])
# mindspore
(2, 1, 7)
(14, 1)
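For reference, NumPy's boolean-mask indexing matches torch's behavior here: a boolean mask whose shape covers the first k dimensions of the array selects jointly along those dimensions and collapses them into a single axis, leaving the remaining axes intact. A sketch showing the expected semantics:

```python
import numpy as np

inp = np.random.randn(2, 1, 7, 11)
val = np.zeros((2, 1, 7))
bound = 1
mask = (val >= -bound) & (val <= bound)  # shape (2, 1, 7), all True here
# The mask covers the first three dims of inp, so those dims collapse
# into one axis of length mask.sum() = 14, keeping the trailing axis of 11.
out = inp[mask, :]
print(mask.shape)  # (2, 1, 7)
print(out.shape)   # (14, 11)
```

The MindSpore result (14, 1) suggests it is not applying this multi-dimensional mask convention; flattening through numpy as above could serve as a workaround.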
That's all of my questions; corrections are welcome. (Incidentally, the torch version also seems to run noticeably faster.)
Solution
Please use the gather_d operator, which matches torch's behavior.
gather is aligned with TensorFlow's gather, not with torch's.
https://www.mindspore.cn/docs/zh-CN/r2.3.0/api_python/ops/mindspore.ops.gather_d.html
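This also explains the (2, 3, 2, 3, 4) shape seen above: TensorFlow-style gather replaces the indexed axis with the entire index array's shape, so the output shape is params.shape[:axis] + indices.shape + params.shape[axis+1:], while torch-style gather keeps the output shape equal to indices.shape. The two conventions can be contrasted with numpy.take vs. numpy.take_along_axis:

```python
import numpy as np

cumwidths = np.arange(1, 25).reshape(2, 3, 4)  # same data as the test case
bin_idx = np.array([[[0, 1, 2, 0], [2, 0, 1, 3], [3, 3, 0, 2]],
                    [[1, 2, 3, 0], [0, 1, 2, 3], [3, 0, 1, 2]]])

# TF-style gather (like ops.gather): axis is replaced by indices.shape
tf_style = np.take(cumwidths, bin_idx, axis=-1)
print(tf_style.shape)     # (2, 3, 2, 3, 4)

# torch-style gather (like ops.gather_d): output shape equals indices.shape
torch_style = np.take_along_axis(cumwidths, bin_idx, axis=-1)
print(torch_style.shape)  # (2, 3, 4)
```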