masked_select算子在模型训练时的异常

在训练我的网络时, 我自定义了一个损失函数, 代码如下:

from mindspore import nn

class Loss(nn.Cell):
    def __init__(self, backbone, loss_fn):
        super(Loss, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, data, label):
        ix = label == label
        output = self._backbone(data)
        output = output.masked_select(ix)
        label = label.masked_select(ix)
        loss = self._loss_fn(output, label)
        return loss

    def construct(self, data, label):
        out = self._backbone(data)
        return self._loss_fn(out, label)

在训练时,就会报如下错误:

​但如果把以下代码改掉:

output = output.masked_select(ix)
label = label.masked_select(ix)

改成

output = output.masked_fill(ix, 0)
label = label.masked_select(ix, 0)

错误就没有了, 请问关于masked_select的异常是怎么回事呢? 这个该怎么解决呢?

当然, 如果可以用bool型的Tensor作为索引的话, 这个问题就不存在了, 我就可以直接 output = output[ix] 了.

****************************************************解答*****************************************************

如果是Ascend上的话,我建议使用Primitive接口P.MaskedSelect()和使用相同的数据定位一下。

P.MaskedSelect()是最底层的接口,排除一下是不是算子本身的问题

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值