深度学习梯度下降简易演示

简易参数更新过程
在这里插入图片描述

在这里插入图片描述

#---------------------------------------------#
#  简易的利用SDG(随机梯度下降)拟合曲线动画过程
#---------------------------------------------#
import random 
import matplotlib.pyplot as plt

_x=[i/100 for i in range(100)]
_y=[3*e+4+random.random() for e in _x]

w=random.random()
b=random.random()
#--------------------------------------------------------------------#
#  开启交互模式,图像将不需要显式调用 plt.show(),在plt.plot()就显示
#--------------------------------------------------------------------#
plt.ion()
#-----------------#
#  10个epoch
#-----------------#
for i in range(10):
    for x,y in zip(_x,_y):
        z=x*w+b
        o=z-y
        loss=o**2
        dw=2*o*x
        db=2*o
        #--------------#
        #  更新参数
        #--------------#
        w=w-0.1*dw
        b=b-0.1*dw
        print(w,b,loss)
        #--------------#
        #  清除当前轴
        #--------------#
        plt.cla()
        plt.plot(_x,_y,'.')
        v=[w*e+b for e in _x]
        plt.plot(_x,v)
        #---------------#
        #  暂停0.01s
        #---------------#
        plt.pause(0.01)
#-----------------#
#  关闭交互模式
#-----------------#
plt.ioff()
#-------------------#
#  显示最后一张图片
#-------------------#
plt.show()

加入激活函数relu

#---------------------------------------------#
#  简易的利用SDG(随机梯度下降)拟合曲线动画过程
#---------------------------------------------#
import random 
import matplotlib.pyplot as plt

_x=[i/100 for i in range(100)]
_y=[3*e+4+random.random() for e in _x]

w=random.random()
b=random.random()

def relu(x):
    if(x>0):
        return x
    return 0
 
def relu_d(x): # (relu导数)
    if(x>0):
        return 1
    return 0

#--------------------------------------------------------------------#
#  开启交互模式,图像将不需要显式调用 plt.show(),在plt.plot()就显示
#--------------------------------------------------------------------#
plt.ion()
#-----------------#
#  10个epoch
#-----------------#
for i in range(10):
    for x,y in zip(_x,_y):
        z=relu(x*w+b)
        o=z-y
        loss=o**2
        dw=2*o*x*relu_d(x*w+b)
        db=2*o*relu_d(x*w+b)
        #--------------#
        #  更新参数
        #--------------#
        w=w-0.1*dw
        b=b-0.1*dw
        print(w,b,loss)
        #--------------#
        #  清除当前轴
        #--------------#
        plt.cla()
        plt.plot(_x,_y,'.')
        v=[relu(w*e+b) for e in _x]
        plt.plot(_x,v)
        #---------------#
        #  暂停0.01s
        #---------------#
        plt.pause(0.01)
#-----------------#
#  关闭交互模式
#-----------------#
plt.ioff()
#-------------------#
#  显示最后一张图片
#-------------------#
plt.show()

参考:https://www.bilibili.com/video/BV15i4y147pf/?spm_id_from=333.337.search-card.all.click&vd_source=9c63f89b714e96dfc638093fbe9f907d

  • 8
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值