【PyTorch深度学习实践 P3】梯度下降

y=wx+b的简单实现
如果没得到w,b=2,1可能是学习率选取不当

# FILE: 学习深度学习/Gradient_Descent
# USER: mcfly
# IDE: PyCharm
# CREATE TIME: 2024/9/2 10:09
# DESCRIPTION: 梯度下降 GD


import numpy as np
from matplotlib import pyplot as plt

x_train = [1.0, 2.0, 3.0, 4.0 ]
y_train = [3.0, 5.0, 7.0, 9.0 ]

def forward( w, b, x):
    return w * x + b

def cal_loss( y, y_head=0.0):
    return (y-y_head) * (y-y_head)

def cal_cost( ylist,y_headlist ):
    sum = 0.0
    for y, y_head in zip( ylist, y_headlist ):
        sum += cal_loss( y, y_head )
    sum /= len(ylist)
    return sum

def gradient_w( w, b, xlist=[], ylist=[]):
    ret = 0.0
    for x, y in zip(xlist, ylist):
        ret += 2 * x * ( x*w+b-y)
    ret /= len(xlist)
    return ret

def gradient_b( w, b, xlist=[], ylist=[]):
    ret = 0.0
    for x, y in zip(xlist, ylist):
        ret += 2 * ( x*w+b-y)
    ret /= len(xlist)
    return ret

w = -2.0
b = 0.0
eta = 0.05 # learning rate
wlist = []
blist = []
cost_dict = {}

for i in range( 1000 ):
    print( "{}-th:".format(i+1))
    y_headlist = [forward(w,b,x) for x in x_train] // 预测值
    print( "\tPrediction: {}".format(y_headlist))
    
    cost = cal_cost(y_train,y_headlist) 
    wlist.append(w)
    blist.append(b)
    cost_dict[(w,b)] = cost
    print("\tw:{0}\n\tb:{1}\n\tloss:{2}".format(w, b, cost))
    
    gw = gradient_w(w, b, x_train, y_train) 
    gb = gradient_b(w, b, x_train, y_train)
    w -= eta * gw
    b -= eta * gb

local_min = min(cost_dict, key=lambda x:cost_dict[x]) // 在字典中找到cost最小的(w,b)
print( "{} is the w,b that minimize the loss.".format(local_min))



x = [key[0] for key in cost_dict.keys()]
y = [key[1] for key in cost_dict.keys()]
z = list(cost_dict.values())

fig = plt.figure(num="Loss") # 创建一个画布figure,然后在这个画布上加各种元素。
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x, y, z, c='b', marker='o')

ax.set_xlabel('w') # 画出坐标轴
ax.set_ylabel('b')
ax.set_zlabel('loss')

plt.show()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值