【笔记】Pytorch-backward()

本文是通过学习 AI研习社 炼丹兄 所做笔记!!!

1. .backward()

目录

1.1 .grad

1.2 .is_leaf

1.3 .retain_grad()

1.4 .grad_fn

1.1 .grad

作用:查看叶子节点梯度值

 

例:y=a*b, a=x+w, b=w+1,计算y关于w的梯度:

\frac{\partial y}{\partial w}=\frac{\partial y}{\partial a}\frac{\partial a}{\partial w}+\frac{\partial y}{\partial b}\frac{\partial b}{\partial w}=2\times w+x+1

(上式计算中,w=1,x=2)

 w 和 x 是叶子节点,是整个计算过程的根基

在反向传播结束之后,非叶子节点的梯度会被释放掉

Pytorch代码实现过程如下:

import torch
w = torch.tensor([1.],requires_grad = True)
x = torch.tensor([2.],requires_grad = True)

a = w + x
b = w + 1
y = a * b

y.backward()
print(w.grad)

运行结果如下:

tensor([5.0])

1.2 .is_leaf

作用:查看是否为叶子节点

例:y=a*b, a=x+w, b=w+1,计算y关于w的梯度:

\frac{\partial y}{\partial w}=\frac{\partial y}{\partial a}\frac{\partial a}{\partial w}+\frac{\partial y}{\partial b}\frac{\partial b}{\partial w}=2\times w+x+1

(上式计算中,w=1,x=2)

import torch
w = torch.tensor([1.],requires_grad = True)
x = torch.tensor([2.],requires_grad = True)

a = w + x
b = w + 1
y = a * b

y.backward()
print(w.is_leaf,x.is_leaf,a.is_leaf,b.is_leaf,y.is_leaf)
print(w.grad,x.grad,a.grad,b.grad,y.grad) 

 运行结果如下:

True True False False False
tensor([5.]) tensor([2.]) None None None

观察到只有 w 和 x 是叶子节点,然后反向传播计算完梯度后(.backward()之后),只有叶子节点的梯度保存下来了。

 1.3 .retain_grad()

作用:通过.retain_grad()保留非任意节点的梯度值

例:y=a*b, a=x+w, b=w+1,计算y关于w的梯度:

\frac{\partial y}{\partial w}=\frac{\partial y}{\partial a}\frac{\partial a}{\partial w}+\frac{\partial y}{\partial b}\frac{\partial b}{\partial w}=2\times w+x+1

(上式计算中,w=1,x=2)

import torch
w = torch.tensor([1.],requires_grad = True)
x = torch.tensor([2.],requires_grad = True)

a = w + x
a.retain_grad()
b = w + 1
y = a * b

y.backward()
print(w.is_leaf,x.is_leaf,a.is_leaf,b.is_leaf,y.is_leaf)
print(w.grad,x.grad,a.grad,b.grad,y.grad)

运行结果如下:

True True False False False
tensor([5.]) tensor([2.]) tensor([2.]) None None

 1.4 .grad_fn

作用:记录创建该张量时所用的函数,这个属性反向传播的时候会用到

例:y=a*b, a=x+w, b=w+1,计算y关于w的梯度:

\frac{\partial y}{\partial w}=\frac{\partial y}{\partial a}\frac{\partial a}{\partial w}+\frac{\partial y}{\partial b}\frac{\partial b}{\partial w}=2\times w+x+1

(上式计算中,w=1,x=2)

 y.grad_fn = MulBackward0 表示y是通过乘法得到的。求导时用乘法的求导法则。

 a.grad_fn = AddBackward0 表示a是通过加法得到的。求导时用加法的求导法则。

        叶子节点的 .grad_f是None

import torch

w = torch.tensor([1.],requires_grad = True)
x = torch.tensor([2.],requires_grad = True)

a = w + x
a.retain_grad()
b = w + 1
y = a * b

y.backward()
print(y.grad_fn)
print(a.grad_fn)
print(w.grad_fn)

运行结果如下:

<MulBackward0 object at 0x000002107EE17610>
<AddBackward0 object at 0x000002107EE17610>
None
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

客官可满意

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值