最速下降法python_Python-梯度下降法(最速下降法)求解多元函数

import matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import Axes3D

import numpy as np

def Fun(x,y): #原函数

return x-y+2*x*x+2*x*y+y*y #方程为z=x1^2+2x2^2-4x1-2x1x2

def PxFun(x,y): #x偏导

return 1+4*x+2*y

def PyFun(x,y): #y偏导

return -1+2*x+2*y

#初始化

fig = plt.figure() #figure对象

ax = Axes3D(fig) #Axes3D对象

X,Y = np.mgrid[-2:2:40j,-2:2:40j] #取样并作满射联合

Z = Fun(X,Y) #取样点Z坐标打表

ax.plot_surface(X,Y,Z,rstride=1,cstride=1,cmap="rainbow")

ax.set_xlabel('x')

ax.set_ylabel('y')

ax.set_zlabel('z')

#梯度下降

step = 0.0008 #下降系数

x = 0

y = 0 #初始选取一个点

tag_x = [x]tag_y = [y]tag_z = [Fun(x,y)] #三个坐标分别打入表中,该表用于绘制点

new_x = x

new_y = y

Over = False

while Over == False:

new_x -= step*PxFun(x,y)

new_y -= step*PyFun(x,y) #分别作梯度下降

if Fun(x,y) - Fun(new_x,new_y) < 7e-9: #精度

Over = True

x = new_x

y = new_y #更新旧点

tag_x.append(x)

tag_y.append(y)

tag_z.append(Fun(x,y)) #新点三个坐标打入表中

#绘制点/输出坐标

ax.plot(tag_x,tag_y,tag_z,'r.')

plt.title('(x,y)~('+str(x)+","+str(y)+')')

plt.show()

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值