最速下降法
优点: 1. 简单、易于掌握
2. 整体收敛性好
缺点: 1. 下降速度一阶线性收敛
2. 搜索路径存在锯齿现象,越靠近解,收敛速度越慢。
from 实用优化算法.helper import * # 单独写一篇博客分享
from sympy import *
import matplotlib.pyplot as plt
# 定义符号
x1, x2, t = symbols('x1, x2, t')
def func(): # 自定义一个函数
return x1**2 + 2*x2**2
def main(X0, theta):
grad_vec = get_grad(X0,func())
print(grad_vec)
grad_length = get_len_grad(grad_vec) # 梯度向量的模长
k = 0
x = x0
data_x = [X0[0][0]]
data_y = [X0[1][0]]
while grad_length > theta: # 迭代的终止条件
k += 1
p = -1 * grad_vec
# 迭代
t = golden_search(0,1,X0,p,func())
print(k,X0)
X0 = X0 + t*p
grad_vec = get_grad(X0,func())
grad_length = get_len_grad(grad_vec)
data_x.append(X0[0][0])
data_y.append(X0[1][0])
print('迭代次数',k)
# 绘图
ax = plt.gca()
plt.plot(data_x, data_y) # x轴数据传入 y轴数据传入
x_major_locator = plt.MultipleLocator(1) # x轴坐标单位值
y_major_locator = plt.MultipleLocator(1) # y轴坐标单位值
ax.xaxis.set_major_locator(x_major_locator)
ax.yaxis.set_major_locator(y_major_locator)
plt.xlim(-10,10) #x 坐标轴的范围
plt.ylim(-5,5) #y 坐标轴的范围
plt.grid() # 图表按网格生成
plt.show()
if __name__ == '__main__':
# 给定初始迭代点和阈值
x0 = [[4],[4]]
print('起始点',x0)
main(np.array(x0), 0.00001)