Backpropagation: y = wx + b
# FILE: 学习深度学习/Back_Propagation
# USER: mcfly
# IDE: PyCharm
# CREATE TIME: 2024/9/2 18:46
# DESCRIPTION: Back Propagation
import torch
# Training data
x_train = [float(i) for i in range(1, 6)]
y_train = [2.0 * i + 3.0 for i in range(1, 6)]  # ground truth: y = 2x + 3
def forward(x):
    return w * x + b  # returns a Tensor, so autograd can track the computation

def cal_loss(y, y_pred):
    return (y - y_pred) ** 2
# Parameters
w = torch.Tensor([1.0])  # a Tensor with a single element
b = torch.Tensor([1.0])
eta = 0.01  # learning rate
w.requires_grad = True
b.requires_grad = True
for epoch in range(1000):  # 1000 epochs
    print("{}-th:".format(epoch))
    for x, y in zip(x_train, y_train):
        y_pred = forward(x)
        loss = cal_loss(y, y_pred)
        loss.backward()  # backward pass: computes d(loss)/dw and d(loss)/db
        w.data = w.data - eta * w.grad.data
        w.grad.data.zero_()  # zero the gradient so it does not carry over into the next step
        b.data = b.data - eta * b.grad.data
        b.grad.data.zero_()
    print("\tw:{0}\n\tb:{1}\n\tloss:{2}".format(w.data.item(), b.data.item(), loss.data.item()))
print( "w,b: ", w.data, b.data )
For y = w_1 x^2 + w_2 x + b, it is just one more parameter to update:
# FILE: 学习深度学习/Back_Propagation
# USER: mcfly
# IDE: PyCharm
# CREATE TIME: 2024/9/2 18:46
# DESCRIPTION: Back Propagation
import torch
# Training data
x_train = [float(i) for i in range(-2, 3)]
y_train = [2*i*i+3*i+1 for i in range(-2, 3)]
def forward(x):
    return w1 * x * x + w2 * x + b  # returns a Tensor, so autograd can track the computation

def cal_loss(y, y_pred):
    return (y - y_pred) ** 2
# Parameters
w1 = torch.Tensor([1.0])  # a Tensor with a single element
w2 = torch.Tensor([1.0])
b = torch.Tensor([1.0])
eta = 0.01  # learning rate
w1.requires_grad = True
w2.requires_grad = True
b.requires_grad = True
for epoch in range(1000):  # 1000 epochs
    print("{}-th:".format(epoch))
    for x, y in zip(x_train, y_train):
        y_pred = forward(x)
        loss = cal_loss(y, y_pred)
        loss.backward()  # backward pass: computes the gradients of loss w.r.t. w1, w2 and b
        w1.data = w1.data - eta * w1.grad.data
        w1.grad.data.zero_()  # zero the gradient so it does not carry over into the next step
        w2.data = w2.data - eta * w2.grad.data
        w2.grad.data.zero_()
        b.data = b.data - eta * b.grad.data
        b.grad.data.zero_()
    print("\tw1:{0}, w2:{1}, b:{2}\n\tloss:{3}".format(w1.data, w2.data, b.data, loss.data))
print( "w1:{}\nw2:{}\nb:{}".format( w1.data, w2.data, b.data ) )
As for why loss can still be read outside the inner loop: that is simply Python scoping (a for loop does not create its own scope), not anything special about it being a Tensor.
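A small demonstration of that point (the names here are only for illustration): any name bound inside a for loop body stays bound after the loop finishes, regardless of its type.

for i in range(3):
    last = i * i  # binds a plain int, not a Tensor
print(i, last)    # prints "2 4" -- both names survive the loop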