pytorch的backward函数用法
首先看一个简单的程式:import torchx = torch.tensor([3, 2], dtype=torch.float32, requires_grad=True)y = x ** 2out = y.mean()out.backward()print(x.grad)输出的结果是:tensor([3., 2.])为什么是这个呢?简单的求导一下就容易理解了。所以,按照这样一个思路求导下来的结果,out对x的梯度,就是x的值。backward还可以传一个参数。
原创
2020-07-28 00:43:56 ·
3159 阅读 ·
4 评论