入门补充 最速梯度下降

名字很叼有木有?

实现了函数最小值的计算,具体说是高中数学知识,完成了求导和梯度下降,找出在某一范围内的最小值,注意调节步长(贼坑)

"""
最速下降法
Rosenbrock函数
函数 f(x)=100*(x(2)-x(1).^2).^2+(1-x(1)).^2
梯度 g(x)=(-400*(x(2)-x(1)^2)*x(1)-2*(1-x(1)),200*(x(2)-x(1)^2))^(T)
"""

import numpy as np
import matplotlib.pyplot as plt
import random

def rosenbrock(x):
    return 100 * (x[1] - x[0] ** 2) ** 2 + (1 - x[0]) ** 2

def jacobian(x):
    return np.array([-400 * x[0] * (x[1] - x[0] ** 2) - 2 * (1 - x[0]), 200 * (x[1] - x[0] ** 2)])

'''
X1 = np.arange(-1.5, 1.5 + 0.05, 0.05)
X2 = np.arange(-3.5, 2 + 0.05, 0.05)
[x1, x2] = np.meshgrid(X1, X2)
f = 100 * (x2 - x1 ** 2) ** 2 + (1 - x1) ** 2  # 给定的函数
#plt.contour(x1, x2, f, 20)  # 画出函数的20条轮廓线
'''

def steepest(x0):
    print('初始点为:')
    print(x0, '\n')
    imax = 20000
    W = np.zeros((2, imax))
    W[:, 0] = x0
    i = 1
    x = x0
    grad = jacobian(x)
    delta = sum(grad ** 2)  # 初始误差
    while i < imax and delta > 10 ** (-5):
        p = -jacobian(x)
        x0 = x
        alpha = goldsteinsearch(rosenbrock, jacobian, p, x, 1, 0.1, 2)
        x = x + alpha * p
        W[:, i] = x
        grad = jacobian(x)
        delta = sum(grad ** 2)
        i = i + 1
    print("迭代次数为:", i)
    print("近似最优解为:")
    print(x, '\n')
    W = W[:, 0:i]  # 记录迭代点
    return W

'''
线性搜索子函数
'''
def goldsteinsearch(f, df, d, x, alpham, rho, t):
    flag = 0
    a = 0
    b = alpham
    fk = f(x)
    gk = df(x)
    phi0 = fk
    dphi0 = np.dot(gk, d)
    alpha = b * random.uniform(0, 1)
    while (flag == 0):
        newfk = f(x + alpha * d)
        phi = newfk
        if (phi - phi0 <= rho * alpha * dphi0):
            if (phi - phi0 >= (1 - rho) * alpha * dphi0):
                flag = 1
            else:
                a = alpha
                b = b
                if (b < alpham):
                    alpha = (a + b) / 2
                else:
                    alpha = t * alpha
        else:
            a = a
            b = alpha
            alpha = (a + b) / 2
    return alpha


x0 = np.array([-1.2, 5])#检测范围
W = steepest(x0)
#plt.plot(W[0, :], W[1, :], 'g*', W[0, :], W[1, :])  # 画出迭代点收敛的轨迹
#plt.show()


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值