【PyTorch深度学习实践 P4】反向传播

反向传播y=wx+b

# FILE: 学习深度学习/Back_Propagation
# USER: mcfly
# IDE: PyCharm
# CREATE TIME: 2024/9/2 18:46
# DESCRIPTION: Back Propagation

import torch

# 训练数据
x_train = [float(i) for i in range(1, 6)]
y_train = [2.0 * i + 3.0 for i in range(1, 5)]



def forward(x):
    return w * x + b # 这里返回的是一个Tensor

def cal_loss(y, y_pred):
    return (y - y_pred) ** 2

# 参数
w = torch.Tensor([1.0]) # 一个1*1的Tensor
b = torch.Tensor([1.0])
eta = 0.01

w.requires_grad = True
b.requires_grad = True

for epoch in range(1000): # 100次迭代
    print("{}-th:".format(epoch))
    for x, y in zip( x_train, y_train ):
        y_pred = forward(x)
        loss = cal_loss(y, y_pred)
        loss.backward() # 反向传播

        w.data = w.data - eta * w.grad.data
        w.grad.data.zero_() # 清零防止下一轮时仍保存上一轮loss对w的偏导
        b.data = b.data - eta * b.grad.data
        b.grad.data.zero_()
    print("\tw:{0}\n\tb:{1}\n\tloss:{2}".format(w.data.item(), b.data.item(), loss.data.item()))

print( "w,b: ", w.data, b.data )

y = w 1 x 2 + w 2 x + b 就是多加一个参数 y = w_1x^2+w_2x+b就是多加一个参数 y=w1x2+w2x+b就是多加一个参数

# FILE: 学习深度学习/Back_Propagation
# USER: mcfly
# IDE: PyCharm
# CREATE TIME: 2024/9/2 18:46
# DESCRIPTION: Back Propagation

import torch

# 训练数据
x_train = [float(i) for i in range(-2, 3)]
y_train = [2*i*i+3*i+1 for i in range(-2, 3)]



def forward(x):
    return w1 * x*x + w2*x + b # 这里返回的是一个Tensor

def cal_loss(y, y_pred):
    return (y - y_pred) ** 2

# 参数
w1 = torch.Tensor([1.0]) # 一个1*1的Tensor
w2 = torch.Tensor([1.0])
b = torch.Tensor([1.0])
eta = 0.01

w1.requires_grad = True
w2.requires_grad = True
b.requires_grad = True

for epoch in range(1000): # 100次迭代
    print("{}-th:".format(epoch))
    for x, y in zip( x_train, y_train ):
        y_pred = forward(x)
        loss = cal_loss(y, y_pred)
        loss.backward() # 反向传播

        w1.data = w1.data - eta * w1.grad.data
        w1.grad.data.zero_() # 清零防止下一轮时仍保存上一轮loss对w的偏导
        w2.data = w2.data - eta * w2.grad.data
        w2.grad.data.zero_()
        b.data = b.data - eta * b.grad.data
        b.grad.data.zero_()
    print("\tw1:{0}, w2:{1}, b:{2}\n\tloss:{3}".format(w1.data, w2.data, b.data, loss.data))

print( "w1:{}\nw2:{}\nb:{}".format( w1.data, w2.data, b.data ) )

至于loss在作用域外能被访问,只能说是语言特性了,并非因为其为Tensor变量
请添加图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值