线性回归

为了深入理解Gradient Descent算法,写了如下代码。

在y = 2x直线上生成随机高斯噪声


# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 15:27:51 2017

@author: liuxy
"""

import numpy as np
import matplotlib.pyplot as plt

def gen_data(size):
    x = np.arange(0, size, 1)
    e = np.random.normal(0, 3, size)
    y = 2*x + e
    return [x, y]


def compute_gradient_full(data,  w):
    X = data[0]
    Y = data[1]
    N = len(X)
    g = np.sum(2*X*(X*w - Y))/N      
    return g
    
def compute_gradient_SGD(data,  w):
    X = data[0]
    Y = data[1]
    
    idx = np.random.randint(0, len(X)-1)
    d = X[idx]
    t = Y[idx]
   
    g = 2*d*(d*w - t) 
    return g

def compute_gradient_miniBatch(data,  w):
    X = data[0]
    Y = data[1]
    
    N = 16
    X_b = []
    Y_b = []
    for i in range(N):
        idx = np.random.randint(0, len(X)-1)
        X_b.append(X[idx])
        Y_b.append(Y[idx])        
    X_ba = np.array(X_b)
    Y_ba = np.array(Y_b)
    g = np.sum(2*X_ba*(X_ba*w - Y_ba))/N      
    return g


def Optimizer(data, w, learning_rate, num_iterator, method, Wts):
    for i in range(num_iterator):
        g = 0
        if ('full' == method):
            g = compute_gradient_full(data, w)   
        if ('mini' == method):
            g = compute_gradient_miniBatch(data, w)   
        if ('sgd' == method):
            g = compute_gradient_SGD(data, w)   
            
        w = w - learning_rate * g
        Wts.append(w)
   
        
data = gen_data(100)
#plt.scatter(data[0], data[1])


lr = 0.000020
w = 6
num = 100

Weights_full = []
Weights_mini = []
Weights_sgd = []

Weights_full.append(w)
Weights_mini.append(w)
Weights_sgd.append(w)
Optimizer(data, w, lr, num, 'full', Weights_full)
Optimizer(data, w, lr, num, 'mini', Weights_mini)
Optimizer(data, w, lr, num, 'sgd', Weights_sgd)
plt.plot(np.arange(0,num+1), Weights_full)  
plt.plot(np.arange(0,num+1), Weights_mini) 
plt.plot(np.arange(0,num+1), Weights_sgd) 




    




权重变化, full, mini batch, sgd


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
SQLAlchemy 是一个 SQL 工具包和对象关系映射(ORM)库,用于 Python 编程语言。它提供了一个高级的 SQL 工具和对象关系映射工具,允许开发者以 Python 类和对象的形式操作数据库,而无需编写大量的 SQL 语句。SQLAlchemy 建立在 DBAPI 之上,支持多种数据库后端,如 SQLite, MySQL, PostgreSQL 等。 SQLAlchemy 的核心功能: 对象关系映射(ORM): SQLAlchemy 允许开发者使用 Python 类来表示数据库表,使用类的实例表示表中的行。 开发者可以定义类之间的关系(如一对多、多对多),SQLAlchemy 会自动处理这些关系在数据库中的映射。 通过 ORM,开发者可以像操作 Python 对象一样操作数据库,这大大简化了数据库操作的复杂性。 表达式语言: SQLAlchemy 提供了一个丰富的 SQL 表达式语言,允许开发者以 Python 表达式的方式编写复杂的 SQL 查询。 表达式语言提供了对 SQL 语句的灵活控制,同时保持了代码的可读性和可维护性。 数据库引擎和连接池: SQLAlchemy 支持多种数据库后端,并且为每种后端提供了对应的数据库引擎。 它还提供了连接池管理功能,以优化数据库连接的创建、使用和释放。 会话管理: SQLAlchemy 使用会话(Session)来管理对象的持久化状态。 会话提供了一个工作单元(unit of work)和身份映射(identity map)的概念,使得对象的状态管理和查询更加高效。 事件系统: SQLAlchemy 提供了一个事件系统,允许开发者在 ORM 的各个生命周期阶段插入自定义的钩子函数。 这使得开发者可以在对象加载、修改、删除等操作时执行额外的逻辑。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值