计算图的概念
这里可以发现,无论加多少个隐藏层, 都是线性拟合函数,因此模型的拟合性能很差,由此引入了非线性函数
非线性函数
backpropagation步骤:
- 创建计算图(前馈计算)
- 计算local gradient(在forward过程中就计算)
- 当从输入计算到最终Loss时,开始回传,一级级回传偏导数
- 利用链式求导法则求出Loss相对于输入的偏导数
即:bp = 前馈+反馈,此处可以结合老师上课讲的例子理解前馈与反馈过程
几个细节:为什么这里要用w.data, w.grad.data? 因为加上data之后才被视为是数值运算,否则会因为是向量而构造计算图!!
没做完一次bp取到梯度之后,要对梯度清零!否则梯度会累加!!
课上例题代码:
import torch
# define dataset
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = torch.tensor([1.0],requires_grad=True)
def forward(x):
return x * w
# 构造计算图的过程
def loss(x, y):
y_pred = forward(x)
return (y_pred - y) ** 2
print("Predict (before)", 4, forward(4).item()) # 注意这里的item!
for epoch in range(100):
for x, y in zip(x_data, y_data):
l = loss(x, y) # l是一个张量,tensor主要是在建立计算图 forward, compute the loss
l.backward() # backward,compute grad for Tensor whose requires_grad set to True
print('\tgrad:', x, y, w.grad.item())
w.data = w.data - 0.01 * w.grad.data # 权重更新时,注意grad也是一个tensor
w.grad.data.zero_() # after update, remember set the grad to zero
print('progress:', epoch, l.item()) # 取出loss使用l.item,不要直接使用l(l是tensor会构建计算图)
print("Predict (after)", 4, forward(4).item()) # 注意这里的item!
课后习题代码
import torch
# define dataset
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w1 = torch.tensor([1.0], requires_grad=True)
w2 = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([1.0], requires_grad=True)
def forward(x):
return (x**2) * w1 + x * w2 + b
# 构造计算图的过程
def loss(x, y):
y_pred = forward(x)
return (y_pred - y) ** 2
print("Predict (before)", 4, forward(4).item()) # 注意这里的item!
for epoch in range(100):
for x, y in zip(x_data, y_data):
l = loss(x, y) # l是一个张量,tensor主要是在建立计算图 forward, compute the loss
l.backward() # backward,compute grad for Tensor whose requires_grad set to True
print('\tgrad:', x, y, w1.grad.item(), w2.grad.item(), b.grad.item())
w1.data = w1.data - 0.01 * w1.grad.data # 权重更新时,注意grad也是一个tensor
w1.grad.data.zero_() # after update, remember set the grad to zero
w2.data = w2.data - 0.01 * w2.grad.data
w2.grad.data.zero_()
b.data = b.data - 0.01 * b.grad.data
b.grad.data.zero_()
print('progress:', epoch, l.item()) # 取出loss使用l.item,不要直接使用l(l是tensor会构建计算图)
print("Predict (after)", 4, forward(4).item()) # 注意这里的item!
计算图可以参考:课后题计算图