MindSpore 如何实现像Torch里面 retain_graph=True 的功能

问题描述

假设首先通过数据x标签y及 mindspore.gard 获取到了模型的梯度g

现在需要将梯度g进行一定的运算(假设为 f(g)),最后 f(g) 需要对最初的 数据x 进行求导,mindspore 该如何实现呢?

如果使用 Torch 的话,在求 梯度g 的时候只要 retain_graph=True, f(g) 就可以直接对 数据x 进行求导。

但是有大佬知道 MindSpore 该如何实现吗?

解决方案

在第一次求梯度g的时候,不调用MindSpore的 NLLLoss() 即可(然而 MindSpore 自带的交叉熵损失也自动调用了 NLLLoss(),因此也需要重写 )。

通过自己写负对数似然损失nll_loss,成功不报错。

%E6%88%AA%E5%B1%8F2024-03-16%2017.22.25.png

  • 9
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值