002_wz_ledr_pytorch深度学习实战_第三讲——梯度下降

一、目的

模拟批量梯度下降算法,计算在x_data、y_data数据集下, y = ω x y={\omega}x y=ωx模型找到合适的 ω \omega ω的值

二、编程

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# 数据集
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = 1

# 正向传播
def forward(x):
    return x * w


# 计算损失
def cost(xs, ys):
    cost = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        cost += (y_pred - y) ** 2
    return cost/len(xs)


# 反向传播
def gradient(xs, ys):
    grad = 0
    for x, y in zip(xs, ys):
        grad += 2 * x * (x * w - y)
    return grad / len(xs)


# 开始训练
mse_list = []
for epoch in range(100):
    # 计算成本
    cost_val = cost(x_data, y_data)
    # 计算梯度
    grad_val = gradient(x_data, y_data)
    # 更新参数
    w = w - 0.01 * grad_val
    mse_list.append(cost_val)
    print("epoch=", epoch, "cost_val=", cost_val, "w=", w)

# 预测
print("x=4, y=", forward(4))

# 绘图
plt.plot(mse_list)
plt.xlabel("epoch")
plt.ylabel("cost")
plt.show()
epoch= 0 cost_val= 4.666666666666667 w= 1.0933333333333333
epoch= 1 cost_val= 3.8362074074074086 w= 1.1779555555555554
epoch= 2 cost_val= 3.1535329869958857 w= 1.2546797037037036
epoch= 3 cost_val= 2.592344272332262 w= 1.3242429313580246
...
epoch= 97 cost_val= 2.593287985380858e-08 w= 1.9999324119941766
epoch= 98 cost_val= 2.131797981222471e-08 w= 1.9999387202080534
epoch= 99 cost_val= 1.752432687141379e-08 w= 1.9999444396553017
x=4, y= 7.999777758621207

在这里插入图片描述
为了解决在机器学习过程中在遇到“鞍点”(即总体所有点的梯度和为0,导致w=w-0.01*0,w不会改变)而导致不能继续进行的问题。可以采用随机梯度下降,即随机的取一组(x,y)的梯度,作为梯度下降的依据,而不用总体所有点的梯度和,作为梯度下降的依据。实质是使用“噪点”去推动梯度下降。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# 数据集
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = 1

# 正向传播
def forward(x):
    return x * w


# 计算损失
def loss(x, y):
    cost = 0
    y_pred = forward(x)
    cost += (y_pred - y) ** 2
    return cost


# 反向传播
def gradient(x, y):
    grad = 0
    grad += 2 * x * (x * w - y)
    return grad


# 开始训练
mse_list = []
for epoch in range(100):
    # 每次使用一个样本进行训练
    for x_val, y_val in zip(x_data, y_data):
        # 计算成本
        cost_val = loss(x_val, y_val)
        # 计算梯度
        grad_val = gradient(x_val, y_val)
        # 更新参数
        w = w - 0.01 * grad_val
    mse_list.append(cost_val)
    print("epoch=", epoch, "cost_val=", cost_val, "w=", w)

# 预测
print("x=4, y=", forward(4))

# 绘图
plt.plot(mse_list)
plt.xlabel("epoch")
plt.ylabel("cost")
plt.show()
epoch= 0 cost_val= 7.315943039999998 w= 1.260688
epoch= 1 cost_val= 3.9987644858206908 w= 1.453417766656
epoch= 2 cost_val= 2.1856536232765476 w= 1.5959051959019805
epoch= 3 cost_val= 1.1946394387269013 w= 1.701247862192685
...
epoch= 97 cost_val= 2.6081713678869703e-25 w= 1.9999999999998603
epoch= 98 cost_val= 1.4248800100554526e-25 w= 1.9999999999998967
epoch= 99 cost_val= 7.82747233205549e-26 w= 1.9999999999999236
x=4, y= 7.9999999999996945

在这里插入图片描述

三、参考

pytorch深度学习实践
PyTorch 深度学习实践 第3讲
PyTorch学习(二)–梯度下降

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
Configure pins as * Analog * Input * Output * EVENT_OUT * EXTI */ static void MX_GPIO_Init(void) { GPIO_InitTypeDef GPIO_InitStruct; /* GPIO Ports Clock Enable */ //__HAL_RCC_GPIOH_CLK_ENABLE(); __HAL_RCC_GPIOC_CLK_ENABLE(); //__HAL_RCC_GPIOA_CLK_ENABLE(); __HAL_RCC_GPIOD_CLK_ENABLE(); __HAL_RCC_GPIOB_CLK_ENABLE(); /*Configure GPIO pin Output Level */ HAL_GPIO_WritePin(LEDR_OUT_PD3_GPIO_Port, LEDR_OUT_PD3_Pin, GPIO_PIN_SET); /*Configure GPIO pin Output Level */ //HAL_GPIO_WritePin(GPIOB, RS485_RE_OUT_PB8_Pin|RS485_SE_OUT_PB9_Pin, GPIO_PIN_RESET); /*Configure GPIO pin : LEDR_OUT_PD3_Pin */ GPIO_InitStruct.Pin = LEDR_OUT_PD3_Pin; GPIO_InitStruct.Mode = GPIO_MODE_OUTPUT_PP; GPIO_InitStruct.Pull = GPIO_PULLUP; GPIO_InitStruct.Speed = GPIO_SPEED_FREQ_VERY_HIGH; HAL_GPIO_Init(LEDR_OUT_PD3_GPIO_Port, &GPIO_InitStruct); /*Configure GPIO pins : RS485_RE_OUT_PB8_Pin RS485_SE_OUT_PB9_Pin */ GPIO_InitStruct.Pin = RS485_RE_OUT_PB8_Pin|RS485_SE_OUT_PB9_Pin; GPIO_InitStruct.Mode = GPIO_MODE_OUTPUT_PP; GPIO_InitStruct.Pull = GPIO_PULLUP; GPIO_InitStruct.Speed = GPIO_SPEED_FREQ_VERY_HIGH; HAL_GPIO_Init(GPIOB, &GPIO_InitStruct); } /* USER CODE BEGIN 4 */ /* USER CODE END 4 */ /** * @brief This function is executed in case of error occurrence. * @param file: The file name as string. * @param line: The line in file as a number. * @retval None */ void _Error_Handler(char *file, int line) { /* USER CODE BEGIN Error_Handler_Debug */ /* User can add his own implementation to report the HAL error return state */ while(1) { } /* USER CODE END Error_Handler_Debug */
07-14
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值