Python:梯度下降实现之小例子

import matplotlib.pyplot as plt
import numpy as np

class GD(object):
    def __init__(self,seed=None, precision=1.E-6):
        self.seed = GD.get_seed(seed)    # 梯度下降算法的种子点
        self.prec = precision            # 梯度下降算法的精度
        self.path = list()               # 记录种子点的路径及相应的目标函数值
        self.solve()                     # 求解主体
        self.display()                   # 数据可视化

    def solve(self):
        x_curr = self.seed
        val_curr = GD.func(*x_curr)
        self.path.append((x_curr,val_curr))

        omega = 1
        while omega > self.prec:
            x_delta = omega * GD.get_grad(*x_curr)
            x_next = x_curr - x_delta
            val_next = GD.func(*x_next)

            if np.abs(val_next - val_curr) < self.prec:
                break
            if val_next < val_curr:
                x_curr = x_next
                val_curr = val_next
                omega *= 1.2
                self.path.append((x_curr,val_curr))
            else:
                omega *= 0.5

    def display(self):
        print("Iteration steps:{}".format(len(self.path)))
        print("Seed:({})".format(",".join(str(item) for item in self.path[0][0])))
        print("Solution:({})".format(",".join(str(item) for item in self.path[-1][0])))

        fig = plt.figure(figsize=(10,4))

        ax1 = plt.subplot(1,2,1)
        ax2 = plt.subplot(1,2,2)

        ax1.plot(np.array(range(len(self.path)))+1, np.array(list(item[1] for item in self.path)),"k.")
        ax1.plot(1, self.path[0][1], 'go', label='starting point')
        ax1.plot(len(self.path), self.path[-1][1],'r*',label='solution')
        ax1.set(xlabel='$iterCnt$', ylabel='$iterVal$')
        ax1.legend()

        x = np.linspace(-100, 100, 500)
        y = np.linspace(-100,100,500)
        x,y = np.meshgrid(x,y)
        z = GD.func(x,y)
        ax2.contour(x,y,z, levels=70)

        x2 = np.array(list(item[0][0] for item in self.path))
        y2 = np.array(list(item[0][1] for item in self.path))
        ax2.plot(x2, y2, 'ko', linewidth=2)
        ax2.plot(x2[0], y2[0], 'go', label='starting point')
        ax2.plot(x2[-1], y2[-1], 'r*', label='solution')

        ax2.set(xlabel='$x$', ylabel='$y$')
        ax2.legend()

        fig.tight_layout()


        plt.show()




    # 内部种子生成函数
    @staticmethod
    def get_seed(seed):
        if seed is not None:
            return np.array(seed)
        return np.random.uniform(-100,100,2)

    # 目标函数
    @staticmethod
    def func(x,y):
        return 5 * x ** 2 + 2 * y ** 2 + 3 * x -10 * y + 4

    # 目标函数的归一化梯度
    @staticmethod
    def get_grad(x,y):
        grad_ori = np.array([10 * x + 3, 4 * y -10])
        length = np.linalg.norm(grad_ori)
        if length == 0:
            return np.zeros(2)
        return grad_ori / length

if __name__ == '__main__':
    GD()

参考来源:https://home.cnblogs.com/u/xxhbdk/

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

DeniuHe

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值