#---------------------------------------------#
# 简易的利用SDG(随机梯度下降)拟合曲线动画过程
#---------------------------------------------#
import random
import matplotlib.pyplot as plt
_x=[i/100 for i in range(100)]
_y=[3*e+4+random.random() for e in _x]
w=random.random()
b=random.random()
#--------------------------------------------------------------------#
# 开启交互模式,图像将不需要显式调用 plt.show(),在plt.plot()就显示
#--------------------------------------------------------------------#
plt.ion()
#-----------------#
# 10个epoch
#-----------------#
for i in range(10):
for x,y in zip(_x,_y):
z=x*w+b
o=z-y
loss=o**2
dw=2*o*x
db=2*o
#--------------#
# 更新参数
#--------------#
w=w-0.1*dw
b=b-0.1*dw
print(w,b,loss)
#--------------#
# 清除当前轴
#--------------#
plt.cla()
plt.plot(_x,_y,'.')
v=[w*e+b for e in _x]
plt.plot(_x,v)
#---------------#
# 暂停0.01s
#---------------#
plt.pause(0.01)
#-----------------#
# 关闭交互模式
#-----------------#
plt.ioff()
#-------------------#
# 显示最后一张图片
#-------------------#
plt.show()
加入激活函数relu
#---------------------------------------------#
# 简易的利用SDG(随机梯度下降)拟合曲线动画过程
#---------------------------------------------#
import random
import matplotlib.pyplot as plt
_x=[i/100 for i in range(100)]
_y=[3*e+4+random.random() for e in _x]
w=random.random()
b=random.random()
def relu(x):
if(x>0):
return x
return 0
def relu_d(x): # (relu导数)
if(x>0):
return 1
return 0
#--------------------------------------------------------------------#
# 开启交互模式,图像将不需要显式调用 plt.show(),在plt.plot()就显示
#--------------------------------------------------------------------#
plt.ion()
#-----------------#
# 10个epoch
#-----------------#
for i in range(10):
for x,y in zip(_x,_y):
z=relu(x*w+b)
o=z-y
loss=o**2
dw=2*o*x*relu_d(x*w+b)
db=2*o*relu_d(x*w+b)
#--------------#
# 更新参数
#--------------#
w=w-0.1*dw
b=b-0.1*dw
print(w,b,loss)
#--------------#
# 清除当前轴
#--------------#
plt.cla()
plt.plot(_x,_y,'.')
v=[relu(w*e+b) for e in _x]
plt.plot(_x,v)
#---------------#
# 暂停0.01s
#---------------#
plt.pause(0.01)
#-----------------#
# 关闭交互模式
#-----------------#
plt.ioff()
#-------------------#
# 显示最后一张图片
#-------------------#
plt.show()
参考:https://www.bilibili.com/video/BV15i4y147pf/?spm_id_from=333.337.search-card.all.click&vd_source=9c63f89b714e96dfc638093fbe9f907d