学习视频:B站 刘二大人《PyTorch深度学习实践》完结合集
三、反向传播
前馈计算
对于比较复杂的网络:
两层神经网络:化简后仍然是线性
引入激活函数,增加非线性
前向传播:
反向传播
利用链式法则计算梯度:
反向传播过程:
代码实现
线性模型
import torch
#已知数据
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]
#线性模型为y=wx,预测x=4时,y的值
#假设w = 1
w = torch.Tensor([1.0])
w.requires_grad = True #设为true计算梯度
#定义模型
def forward(x):
return x * w
def loss(x,y): # 计算损失函数
y_pred = forward(x)
return (y_pred - y) ** 2
print("predict(before training",4,'%.2f'%(forward(4)))
for epoch in range(100):
for x,y in zip(x_data,y_data):
l = loss(x,y)
l.backward() # 对requires_grad = True的Tensor(w)计算其梯度并进行反向传播,并且会释放计算图进行下一次计算
print("\tgrad:,%.1f %.1f %.2f" % (x,y,w.grad.item()))
w.data = w.data - 0.01 * w.grad.data #w.grad.data获得w,通过梯度对w进行更新
w.grad.data.zero_() #将权重里面的梯度清零
print("Epoch:%d, w = %.2f, loss = %.2f" % (epoch,w,l.item()))
print("predict(after training",4,'%.2f'%(forward(4)))
运行结果:
- 算法反向传播体现在,l.backward(),调用该方法后w.grad由none更新为Tensor类型,且w.grad.data的值用于后续w.data的更新
- l.backward()会把计算图中所有需要梯度(grad)的地方都会求出来,然后把梯度都存在对应的待求的参数中,最终计算图被释放。
- 取tensor中的data是不会构建计算图的。
课后练习
1.计算y = wx的梯度
2.计算y = wx + b的梯度
3.计算 y = w 1 x 2 + w 2 x + b y=w_1x^2 + w_2x+b y=w1x2+w2x+b的梯度
用pytorch代码实现:
代码:
import torch
#已知数据
x_data = [1.0,2.0,3.0]
y_data = [6.0,11.0,18.0]
#线性模型为y=w1x^2 + w2x +b,预测x=4时,y的值
#假设w = 1,b=1
w1 = torch.Tensor([1.0])
w1.requires_grad = True #设为true计算梯度
w2 = torch.Tensor([1.0])
w2.requires_grad = True
b = torch.Tensor([1.0])
b.requires_grad = True
#定义模型
def forward(x):
return w1*x*x + w2*x +b
def loss(x,y): # 计算损失函数
y_pred = forward(x)
return (y_pred - y) ** 2
print("predict(before training)",4,'%.2f'%(forward(4)))
for epoch in range(100):
for x,y in zip(x_data,y_data):
l = loss(x,y)
l.backward() # 对requires_grad = True的Tensor(w)计算其梯度并进行反向传播,并且会释放计算图进行下一次计算
w1.data = w1.data - 0.01 * w1.grad.data #w.grad.data获得w,通过梯度对w进行更新
w2.data = w2.data - 0.01 * w2.grad.data
b.data = b.data - 0.01 * b.grad.data
w1.grad.data.zero_() #将权重里面的梯度清零
w2.grad.data.zero_()
b.grad.data.zero_()
print("Epoch:%d, w1 = %.2f,w2 = %.2f,b = %.2f loss = %.2f" % (epoch,w1,w2,b,l.item()))
print("predict(after training)",4,'%.2f'%(forward(4)))
运行结果:
参考资料
https://blog.csdn.net/qq_43800119/article/details/126415332