Applying Stochastic Gradient Descent Algorithms to Animation Scenarios


https://github.com/lilipads/gradient_descent_viz


#include <cmath>
#include <iostream>
#include <string>

namespace Function {
    enum FunctionName {
        local_minimum,  // many small pits (local minima)
        global_minimum, // one big pit
        saddle_point,   // undulating mountain road
        ecliptic_bowl,  // undulating big pit
        hills,          // big pit + small pits + valley
        plateau         // big pit + ripples
    };
}

class Point { 
public:
    Point() : x(0.), z(0.) {}
    Point(double x1, double z1) : x(x1), z(z1) {}
    double x; 
    double z;
};

const float kBallRadiusPerGraph = 24.63;

class GradientDescent {
public:
    GradientDescent();
    virtual ~GradientDescent() {}

    double learning_rate = 0.001;
    static Function::FunctionName function_name;

    // simple getters and setters
    Point position() { return p; }
    void setStartingPosition(double x, double z) { starting_p.x = x; starting_p.z = z; }
    bool isConverged() { return is_converged; };
    double gradX() { return grad.x; };
    double gradZ() { return grad.z; };
    Point gradPosition() { return grad; };
    Point delta() { return m_delta; }


    // core methods
    static double f(double x, double z);
    Point takeGradientStep();
    void resetPositionAndComputeGradient();

protected:
    Point p; // current position
    Point starting_p; // starting position
    Point m_delta; // movement in each direction after a gradient step
    Point grad; // gradient at the current position
    bool is_converged = false;

    void setPositionAndComputeGradient(double x, double z);
    void computeGradient();
    virtual void updateGradientDelta() = 0;
    virtual void resetState() {}
};


class VanillaGradientDescent : public GradientDescent {
public:
    VanillaGradientDescent() {}

protected:
    void updateGradientDelta();
};


class Momentum : public GradientDescent {
public:
    Momentum() {}

    double decay_rate = 0.8;

protected:
    void updateGradientDelta();
};

class AdaGrad : public GradientDescent {
public:
    AdaGrad() {}
    Point gradSumOfSquared() { return grad_sum_of_squared; }

protected:
    void updateGradientDelta();
    void resetState();

private:
    Point grad_sum_of_squared;
};

class RMSProp : public GradientDescent {
public:
    RMSProp() {}

    double decay_rate = 0.99;
    Point decayedGradSumOfSquared() { return decayed_grad_sum_of_squared; }

protected:
    void updateGradientDelta();
    void resetState();

private:
    Point decayed_grad_sum_of_squared;
};

class Adam : public GradientDescent {
public:
    Adam() {}

    double beta1 = 0.9;
    double beta2 = 0.999;
    Point decayedGradSum() { return decayed_grad_sum; }
    Point decayedGradSumOfSquared() { return decayed_grad_sum_of_squared; }

protected:
    void updateGradientDelta();
    void resetState();

private:
    Point decayed_grad_sum;
    Point decayed_grad_sum_of_squared;
};


const double kDivisionEpsilon = 1e-12;
const double kFiniteDiffEpsilon = 1e-12;
const double kConvergenceEpsilon = 1e-2;

Function::FunctionName GradientDescent::function_name = Function::local_minimum;


GradientDescent::GradientDescent()
{
    resetPositionAndComputeGradient();
}


double GradientDescent::f(double x, double z) {
    switch (function_name) {
    case Function::local_minimum: {
        z *= 1.4;
        return -2 * exp(-((x - 1) * (x - 1) + z * z) / .2) -
            6. * exp(-((x + 1) * (x + 1) + z * z) / .2) +
            x * x + z * z;
    }
    case Function::global_minimum: {
        return x * x + z * z;
    }
    case Function::saddle_point: {
        return sin(x) + z * z;
    }
    case Function::ecliptic_bowl: {
        x /= 2.;
        z /= 2.;
        return -exp(-(x * x + 5 * z * z)) + x * x + 0.5 * z * z;
    }
    case Function::hills: {
        z *= 1.4;
        return  2 * exp(-((x - 1) * (x - 1) + z * z) / .2) +
            6. * exp(-((x + 1) * (x + 1) + z * z) / .2) -
            2 * exp(-((x - 1) * (x - 1) + (z + 1) * (z + 1)) / .2) +
            x * x + z * z;
    }
    case Function::plateau: {
        x *= 10;
        z *= 10;
        double r = sqrt(z * z + x * x) + 0.01;
        return -sin(r) / r + 0.01 * r * r;
    }
    }
    return 0.;
}


void GradientDescent::computeGradient() {
    // use finite difference method
    grad.x = (f(p.x + kFiniteDiffEpsilon, p.z) -
        f(p.x - kFiniteDiffEpsilon, p.z)) / (2 * kFiniteDiffEpsilon);

    grad.z = (f(p.x, p.z + kFiniteDiffEpsilon) -
        f(p.x, p.z - kFiniteDiffEpsilon)) / (2 * kFiniteDiffEpsilon);
}
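
// The partial derivatives are approximated with central differences:
//     df/dx ≈ (f(x + h, z) - f(x - h, z)) / (2h),    h = kFiniteDiffEpsilon,
// and likewise for df/dz, so no analytic derivative of f is required.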

void GradientDescent::resetPositionAndComputeGradient() {
    is_converged = false;
    m_delta = Point(0, 0);
    resetState();
    setPositionAndComputeGradient(starting_p.x, starting_p.z);
}


void GradientDescent::setPositionAndComputeGradient(double x, double z) {
    /* set position and dirty gradient */

    p.x = x;
    p.z = z;
    computeGradient();
}

Point GradientDescent::takeGradientStep() {
    /* take a gradient step. return the new position
     * side effects:
     * - update delta to the step just taken
     * - update position to new position.
     * - update grad to gradient of the new position
     */

    if (std::abs(gradX()) < kConvergenceEpsilon &&
        std::abs(gradZ()) < kConvergenceEpsilon) {
        is_converged = true;
    }
    if (is_converged) return p;

    updateGradientDelta();
    setPositionAndComputeGradient(p.x + m_delta.x, p.z + m_delta.z);
    return p;
}

void VanillaGradientDescent::updateGradientDelta() {
    m_delta.x = -learning_rate * grad.x;
    m_delta.z = -learning_rate * grad.z;
}
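
// Plain ("vanilla") gradient descent: each step moves a fixed fraction of the
// negative gradient, delta = -learning_rate * grad, so the ball always rolls
// straight downhill and slows down as the slope flattens.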

void Momentum::updateGradientDelta() {
    /* https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Momentum */

    m_delta.x = decay_rate * m_delta.x - learning_rate * grad.x;
    m_delta.z = decay_rate * m_delta.z - learning_rate * grad.z;
}
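
// Momentum accumulates an exponentially decaying "velocity":
//     delta = decay_rate * delta - learning_rate * grad.
// Successive steps in the same direction build up speed, so the ball behaves
// as if it had inertia and can overshoot the minimum before settling.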


void AdaGrad::updateGradientDelta() {
    /* https://en.wikipedia.org/wiki/Stochastic_gradient_descent#AdaGrad */

    grad_sum_of_squared.x += pow(grad.x, 2);
    grad_sum_of_squared.z += pow(grad.z, 2);
    m_delta.x = -learning_rate * grad.x / (sqrt(grad_sum_of_squared.x) + kDivisionEpsilon);
    m_delta.z = -learning_rate * grad.z / (sqrt(grad_sum_of_squared.z) + kDivisionEpsilon);
}
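
// AdaGrad scales each coordinate's step by 1 / sqrt(sum of its squared
// gradients), so directions that have seen large gradients take ever smaller
// steps; the effective learning rate only shrinks over time.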


void AdaGrad::resetState() {
    grad_sum_of_squared = Point(0, 0);
}


void RMSProp::updateGradientDelta() {
    /* https://en.wikipedia.org/wiki/Stochastic_gradient_descent#RMSProp */

    decayed_grad_sum_of_squared.x *= decay_rate;
    decayed_grad_sum_of_squared.x += (1 - decay_rate) * pow(grad.x, 2);
    decayed_grad_sum_of_squared.z *= decay_rate;
    decayed_grad_sum_of_squared.z += (1 - decay_rate) * pow(grad.z, 2);
    m_delta.x = -learning_rate * grad.x / (sqrt(decayed_grad_sum_of_squared.x) + kDivisionEpsilon);
    m_delta.z = -learning_rate * grad.z / (sqrt(decayed_grad_sum_of_squared.z) + kDivisionEpsilon);
}
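
// RMSProp replaces AdaGrad's ever-growing sum with an exponential moving
// average of squared gradients (weight decay_rate), so the per-coordinate
// step size can recover after a stretch of large gradients instead of
// shrinking monotonically.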


void RMSProp::resetState() {
    decayed_grad_sum_of_squared = Point(0, 0);
}


void Adam::updateGradientDelta() {
    /* https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Adam */

    // first moment (momentum)
    decayed_grad_sum.x *= beta1;
    decayed_grad_sum.x += (1 - beta1) * grad.x;
    decayed_grad_sum.z *= beta1;
    decayed_grad_sum.z += (1 - beta1) * grad.z;
    // second moment (rmsprop)
    decayed_grad_sum_of_squared.x *= beta2;
    decayed_grad_sum_of_squared.x += (1 - beta2) * pow(grad.x, 2);
    decayed_grad_sum_of_squared.z *= beta2;
    decayed_grad_sum_of_squared.z += (1 - beta2) * pow(grad.z, 2);

    m_delta.x = -learning_rate * decayed_grad_sum.x /
        (sqrt(decayed_grad_sum_of_squared.x) + kDivisionEpsilon);
    m_delta.z = -learning_rate * decayed_grad_sum.z /
        (sqrt(decayed_grad_sum_of_squared.z) + kDivisionEpsilon);
}
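
// Adam combines both ideas: an exponential moving average of the gradient
// (first moment, beta1) and of the squared gradient (second moment, beta2),
// then steps by the ratio of the two. Unlike the textbook formulation, this
// implementation omits the bias-correction factors 1/(1 - beta1^t) and
// 1/(1 - beta2^t), which mainly affects the first few steps.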


void Adam::resetState() {
    decayed_grad_sum_of_squared = Point(0, 0);
    decayed_grad_sum = Point(0, 0);
}


void print(const std::string& title, const Point& pt) {
    // print the point as (x, y, z), where y = f(x, z) is the height of the surface
    double scrollOffset = 2000;
    double yOffset = scrollOffset / kBallRadiusPerGraph;
    yOffset = 0;  // the offset is zeroed out, so only the raw surface height is printed
    std::cout << title << "(" << pt.x << ", "
              << GradientDescent::f(pt.x, pt.z) + yOffset << ", " << pt.z << ")"
              << std::endl;
}

void test() {
    // can be used for an ease-out that slowly comes to a stop after a swipe
    VanillaGradientDescent vanillaGradientDescent;
    vanillaGradientDescent.learning_rate = 0.01;                 // speed
    GradientDescent::function_name = Function::global_minimum;   // scene (static, shared by all optimizers)
    vanillaGradientDescent.setStartingPosition(100, 100);        // starting position at the top
    vanillaGradientDescent.resetPositionAndComputeGradient();    // reset the position and recompute the gradient
    print("pt", vanillaGradientDescent.position());
    print("gradient", vanillaGradientDescent.gradPosition());
    print("delta", vanillaGradientDescent.delta());
    for (size_t i = 0; i < 60; i++)
    {   // 60 steps
        vanillaGradientDescent.takeGradientStep();
        print("pt", vanillaGradientDescent.position());
        print("gradient", vanillaGradientDescent.gradPosition());
        print("delta", vanillaGradientDescent.delta());
    }

    // can be used for a bounce-back that oscillates a few times before gradually stopping after a swipe
    Momentum momentum;
    momentum.learning_rate = 0.01;               // speed
    momentum.decay_rate = 0.9;                   // friction coefficient; 1 means the ideal, frictionless case
    momentum.setStartingPosition(100, 100);      // starting position at the top
    momentum.resetPositionAndComputeGradient();  // reset the position and recompute the gradient
    print("pt", momentum.position());
    print("gradient", momentum.gradPosition());
    print("delta", momentum.delta());
    for (size_t i = 0; i < 60; i++)
    {   // 60 steps
        momentum.takeGradientStep();
        print("pt", momentum.position());
        print("gradient", momentum.gradPosition());
        print("delta", momentum.delta());
    }
}
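
The listing above has no entry point; assuming everything above sits in one translation unit, a minimal main() to run the demo could simply be:

int main() {
    test();
    return 0;
}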

