一、牛顿法概述
除了前面说的梯度下降法,牛顿法也是机器学习中用的比较多的一种优化算法。牛顿法的基本思想是利用迭代点处的一阶导数(梯度)和二阶导数(
Hessen
矩阵)对目标函数进行二次函数近似,然后把二次模型的极小点作为新的迭代点,并不断重复这一过程,直至求得满足精度的近似极小值。牛顿法的速度相当快,而且能高度逼近最优值。牛顿法分为基本的牛顿法和全局牛顿法。
二、基本牛顿法
1、基本牛顿法的原理
基本牛顿法是一种是用导数的算法,它每一步的迭代方向都是沿着当前点函数值下降的方向。
我们主要集中讨论在一维的情形,对于一个需要求解的优化函数,求函数的极值的问题可以转化为求导函数。对函数进行泰勒展开到二阶,得到
对上式求导并令其为0,则为
即得到
这就是牛顿法的更新公式。
2、基本牛顿法的流程
牛顿法最突出的优点是收敛速度快,具有局部二阶收敛性,但是,基本牛顿法初始点需要足够“靠近”极小点,否则,有可能导致算法不收敛。这样就引入了全局牛顿法。
1、全局牛顿法的流程
全局牛顿法是基于
Armijo
的搜索,满足
Armijo
准则:
给定,,令步长因子,其中是满足下列不等式的最小非负整数:
四、算法实现
实验部分使用Java实现,需要优化的函数,最小值为。
1、基本牛顿法Java实现
- package org.algorithm.newtonmethod;
- /**
- * Newton法
- *
- * @author dell
- *
- */
- public class NewtonMethod {
- private double originalX;// 初始点
- private double e;// 误差阈值
- private double maxCycle;// 最大循环次数
- /**
- * 构造方法
- *
- * @param originalX初始值
- * @param e误差阈值
- * @param maxCycle最大循环次数
- */
- public NewtonMethod(double originalX, double e, double maxCycle) {
- this.setOriginalX(originalX);
- this.setE(e);
- this.setMaxCycle(maxCycle);
- }
- // 一系列get和set方法
- public double getOriginalX() {
- return originalX;
- }
- public void setOriginalX(double originalX) {
- this.originalX = originalX;
- }
- public double getE() {
- return e;
- }
- public void setE(double e) {
- this.e = e;
- }
- public double getMaxCycle() {
- return maxCycle;
- }
- public void setMaxCycle(double maxCycle) {
- this.maxCycle = maxCycle;
- }
- /**
- * 原始函数
- *
- * @param x变量
- * @return 原始函数的值
- */
- public double getOriginal(double x) {
- return x * x - 3 * x + 2;
- }
- /**
- * 一次导函数
- *
- * @param x变量
- * @return 一次导函数的值
- */
- public double getOneDerivative(double x) {
- return 2 * x - 3;
- }
- /**
- * 二次导函数
- *
- * @param x变量
- * @return 二次导函数的值
- */
- public double getTwoDerivative(double x) {
- return 2;
- }
- /**
- * 利用牛顿法求解
- *
- * @return
- */
- public double getNewtonMin() {
- double x = this.getOriginalX();
- double y = 0;
- double k = 1;
- // 更新公式
- while (k <= this.getMaxCycle()) {
- y = this.getOriginal(x);
- double one = this.getOneDerivative(x);
- if (Math.abs(one) <= e) {
- break;
- }
- double two = this.getTwoDerivative(x);
- x = x - one / two;
- k++;
- }
- return y;
- }
- }
2、全局牛顿法Java实现
- package org.algorithm.newtonmethod;
- /**
- * 全局牛顿法
- *
- * @author dell
- *
- */
- public class GlobalNewtonMethod {
- private double originalX;
- private double delta;
- private double sigma;
- private double e;
- private double maxCycle;
- public GlobalNewtonMethod(double originalX, double delta, double sigma,
- double e, double maxCycle) {
- this.setOriginalX(originalX);
- this.setDelta(delta);
- this.setSigma(sigma);
- this.setE(e);
- this.setMaxCycle(maxCycle);
- }
- public double getOriginalX() {
- return originalX;
- }
- public void setOriginalX(double originalX) {
- this.originalX = originalX;
- }
- public double getDelta() {
- return delta;
- }
- public void setDelta(double delta) {
- this.delta = delta;
- }
- public double getSigma() {
- return sigma;
- }
- public void setSigma(double sigma) {
- this.sigma = sigma;
- }
- public double getE() {
- return e;
- }
- public void setE(double e) {
- this.e = e;
- }
- public double getMaxCycle() {
- return maxCycle;
- }
- public void setMaxCycle(double maxCycle) {
- this.maxCycle = maxCycle;
- }
- /**
- * 原始函数
- *
- * @param x变量
- * @return 原始函数的值
- */
- public double getOriginal(double x) {
- return x * x - 3 * x + 2;
- }
- /**
- * 一次导函数
- *
- * @param x变量
- * @return 一次导函数的值
- */
- public double getOneDerivative(double x) {
- return 2 * x - 3;
- }
- /**
- * 二次导函数
- *
- * @param x变量
- * @return 二次导函数的值
- */
- public double getTwoDerivative(double x) {
- return 2;
- }
- /**
- * 利用牛顿法求解
- *
- * @return
- */
- public double getGlobalNewtonMin() {
- double x = this.getOriginalX();
- double y = 0;
- double k = 1;
- // 更新公式
- while (k <= this.getMaxCycle()) {
- y = this.getOriginal(x);
- double one = this.getOneDerivative(x);
- if (Math.abs(one) <= e) {
- break;
- }
- double two = this.getTwoDerivative(x);
- double dk = -one / two;// 搜索的方向
- double m = 0;
- double mk = 0;
- while (m < 20) {
- double left = this.getOriginal(x + Math.pow(this.getDelta(), m)
- * dk);
- double right = this.getOriginal(x) + this.getSigma()
- * Math.pow(this.getDelta(), m)
- * this.getOneDerivative(x) * dk;
- if (left <= right) {
- mk = m;
- break;
- }
- m++;
- }
- x = x + Math.pow(this.getDelta(), mk)*dk;
- k++;
- }
- return y;
- }
- }
3、主函数
- package org.algorithm.newtonmethod;
- /**
- * 测试函数
- * @author dell
- *
- */
- public class TestNewton {
- public static void main(String args[]) {
- NewtonMethod newton = new NewtonMethod(0, 0.00001, 100);
- System.out.println("基本牛顿法求解:" + newton.getNewtonMin());
- GlobalNewtonMethod gNewton = new GlobalNewtonMethod(0, 0.55, 0.4,
- 0.00001, 100);
- System.out.println("全局牛顿法求解:" + gNewton.getGlobalNewtonMin());
- }
- }