PyTorch学习笔记梯度下降法代码

一元一次方程求解

求解一元一次方程,如y=wx+b,其中x,y数据都是存在噪声的,要想求得该方程的参数,w和b值,需要求解成百组乃至上千组表现好的方程组。

实战代码流程:

数据点生成

随机生成100组数据点,代码如下:

import numpy as np
from numpy import random
point=np.random.normal(50,2.5,size=(100,2))
np.savetxt("data.csv",point,delimiter=',')

误差公式

定义的Loss公式的函数如下:
l o s s = ∑ ( W X + b − y ) 2 / n loss = \sum(WX+b-y)^{2} /n loss=(WX+by)2/n

所定义的loss函数代码如下:

def compute_error_for_line_given_points(b,w,points):
    totalError = 0
    for i in range(0,len(points)):
        x=points[i,0]
        y=points[i,1]
        totalError += (y-(w*x+b))**2
        return totalError / float(len(points))

梯度下降法

梯度下降法的公式:

w ′ = w − l r ∇ l o s s ∇ w , b ′ = b − l r ∇ l o s s ∇ b w^{'}=w-lr { \frac{\nabla loss}{\nabla w} } , b^{'}=b-lr { \frac{\nabla loss}{\nabla b} } w=wlrwloss,b=blrbloss

定义的梯度下降法函数如下:

def step_gradient(b_current,w_current,points,learningRate):
    b_gradient = 0
    w_gradient = 0
    N = float(len(points))
    for i in range(0,len(points)):
        x=points[i,0]
        y=points[i,1]
        b_gradient+=-(2/N)*(y-((w_current*x)+b_current))
        w_gradient+=-(2/N)*x*(y-((w_current*x)+b_current))
    new_b = b_current - (learningRate * b_gradient)
    new_w = w_current - (learningRate * w_gradient)
    return [new_b, new_w]

循环迭代

def gradient_descent_runner(points,starting_b,starting_w,learning_rate,num_iterations):
    b = starting_b
    w = starting_w
    for i in range(num_iterations):
        b,w = step_gradient(b,w,np.array(points),learning_rate)
    return [b,w]

总运行函数

def run():
    points = np.genfromtxt("data.csv",delimiter=",")
    learning_rate = 0.0001
    initial_b = 0
    initial_w = 0
    num_iterations = 1000
    print("starting gradient descent at b={0},w={1},error={2}"
          .format(initial_b,initial_w,compute_error_for_line_given_points(initial_b,initial_w,points))
          )
    print("Running ...")
    [b,w]=gradient_descent_runner(points,initial_b,initial_w,learning_rate,num_iterations)
    print("After {0} iterations b={1},w={2},error={3}".format(num_iterations,b,w,
                                                              compute_error_for_line_given_points(b,w,points)))

程序入口

if __name__=='__main__':
    run()
``
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值