pytorch 中 torch.sqrt 的坑

记录一下,避免后面踩坑

单步调了半小时,幸好这个问题出现很频繁
发现是函数与导函数定义域问题。。。

sqrt(x) 函数的定义域为 [0, 无穷大)
sqrt(x) 的导函数的定义域 却是 (0, 无穷大)

这些函数定义域跟导函数的定义域不一样,正向传播可以得到正常结果,但是一旦backward就会得到Nan。。。

问题重现

import torch
a = torch.zeros(1)
a.requireds_grad = True
b = torch.sqrt(a)
b.backward()
print(a.grad)
# 得到nan

如何解决
让输入的值符合sqrt的导函数定义域就可以解决该问题了。举个例子:设 x 的定义域为 [0, 无穷大) ,给 x 加个很小的数,例如1e-8,使其输入值的定义域略微往右偏移,就可以避开 0 这个未定义值了;y = sqrt(x + 1e-8)

  • 28
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 12
    评论
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值