【MindSpore】【源码分析】StridedSlice 逻辑与TensorFlow 不一致及算子替换方案

转载地址:https://bbs.huaweicloud.com/forum/thread-98897-1-1.html

作者: 杨德志

【代码模块】

mindspore.ops.StridedSlice

【问题描述 & 代码分析】

在用MindSpore写模型的时候,遇到StridedSlice对mask参数的处理与TensorFlow不一致的问题。

MindSpore也没有给详细的说明,于是扒了扒MindSpore和TensorFlow的源码,以 new_axis_mask 参数为例。

1. MindSpore里的处理:

有多少位被置1,就在最前面位置增加多少个长度为1的维度。

https://gitee.com/mindspore/mindspore/blob/master/mindspore/ops/operations/array_ops.py

class StridedSlice => def _compute_slicing_shape

关于new_axis_mask 参数没有明确的注释说明是怎么运作的,关键的代码片段如下:

if j < len(new_axis_pos) and new_axis_pos[j] == '1':

    ret_shape.append(1)

    j += 1

    continue

2. TensorFlow里的注释说明:

如果第 i 位被置1,会在对应的位置增加一维,长度为1。

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/array_ops.py

参数的注释说明如下:

If the ith bit of `new_axis_mask` is set, then `begin`, 

`end`, and `stride` are ignored and a new length 1 dimension is 

added at this point in the output tensor.

【算子替换方案】

从逻辑上,两者对new_axis_mask 做了不同的处理,如果需要保持逻辑一致,可以用 expand_dims 来替代 new_axis_mask 的工作。

截图就略过了,很好复现,比如可以把new_axis_mask设置成4,他们会在不同的位置增加1个维度。

其他的mask其实也发现了一些问题,暂时不建议在MindSpore里用这五个mask参数(begin_maskend_maskellipsis_masknew_axis_maskshrink_axis_mask)。

可以修改beginendstrides,也试试用别的算子替代。

issue:https://gitee.com/mindspore/mindspore/issues/I2BKFK

华为邮箱:yangdezhi3@huawei.com

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值