3.4 随机梯度下降【斯坦福21秋季:实用机器学习中文版】代码实现

文章目录

题目

3.4 随机梯度下降【斯坦福21秋季:实用机器学习中文版】代码实现

代码

'''
Description: SGD代码实现
Autor: 365JHWZGo
Date: 2022-03-20 12:10:30
LastEditors: 365JHWZGo
LastEditTime: 2022-03-20 17:46:25
'''
import random
import torch
import matplotlib.pyplot as plt
EPOCH = 40
BATCH_SIZE = 32
m = 3
NUM = 1000
LR = 0.03

# 创造数据
def create_data(w, b, num_examples):
    # w.shape [m, 1]
    # b.shape [1]

    # X.shape [num_examples, m]
    X = torch.normal(0, 1, (num_examples, len(w)))

    # y = X*w = [num_examples, 1]
    y = torch.matmul(X, w) + b

    # Y.shape [num_examples, 1]
    Y = torch.normal(0, 0.01, y.shape)+y
    return X, Y

# batch数据截取
def data_iteration(batch_size, features, labels):
    # features.shape [num_examples, m]
    # labels.shape [num_examples, 1]

    num_examples = len(features)
    indices_num = list(range(num_examples))
    random.shuffle(indices_num)
    for i in range(0, num_examples, batch_size):
        data_indices = torch.tensor(
            # 当数据不足切片时,取到num_examples
            indices_num[i:min(i+batch_size, num_examples)]
        )
        # features.shape [batch_size, m]
        # labels.shape [batch_size, 1]
        yield features[data_indices], labels[data_indices]

# SGD函数
def SGD(y_acc, y_pre):
    loss = ((y_acc - y_pre)**2/2).sum()
    return loss

# 函数预测
def linear_predict(x, w, b):
    y = torch.matmul(x, w)+b
    return y


'''
实际上的w和b
'''
# w_acc.shape [m, 1]
w_acc = torch.tensor([[-2.0], [2.5], [-1.9]])
# b_acc.shape [1]
b_acc = torch.tensor([8.1])

# 生成有噪音的数据
features, labels = create_data(w_acc, b_acc, NUM)

'''
画图
plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1)
plt.show()
'''

'''
预测时初始化w和b
'''
# w.shape [m, 1] 可导
w = torch.normal(0, 0.01, (m,1), requires_grad=True)
# b.shape [1] 可导
b = torch.zeros(1, requires_grad=True)

# 主函数入口
if __name__ == '__main__':
    
    for epoch in range(EPOCH):
        # features.shape [batch_size, m]
        # labels.shape [batch_size, 1]
        for f, l in data_iteration(BATCH_SIZE, features, labels):

            # 进行y值预测 y_pre.shape [batch_size, 1]
            y_pre = linear_predict(f, w, b)

            # SGD梯度下降
            loss = SGD(l, y_pre)

            # 求导
            loss.backward()

            with torch.no_grad():
                for param in [w, b]:
                    param -= LR * param.grad / BATCH_SIZE
                    param.grad.zero_()

        # test
        with torch.no_grad():
            test_loss = SGD(labels, linear_predict(features, w, b)).mean()
            print(f'epoch:{epoch},test_loss={test_loss}')
    print(f'预测的w={w},实际的w={w_acc}\n预测的b={b},实际的b={b_acc}')

在这里插入图片描述
在这里插入图片描述

'''
画图
'''
sample_x = np.linspace(-10,10,10)
w_accurancy = w_acc.flatten().numpy()
b_accurancy = b_acc.flatten().numpy()
sample_y = []
sample_y_pre = []
w_prediction = w.flatten().detach().numpy()
b_prediction = b.flatten().detach().numpy()
for i in range(len(sample_x)):
    sample_y.append(sample_x[i]*w_accurancy[0]+sample_x[i]*w_accurancy[1]+sample_x[i]*w_accurancy[2]+b_accurancy[0])
    sample_y_pre.append(sample_x[i]*w_prediction[0]+sample_x[i]*w_prediction[1]+sample_x[i]*w_prediction[2]+b_prediction[0])
plt.figure(num=0, figsize=(8, 5))
plt.plot(sample_x,sample_y,label="真实函数")
plt.plot(sample_x,sample_y_pre,color='red', linewidth=1.0, linestyle='--',label="预测函数")
plt.show()

真实函数和预测函数比较
在这里插入图片描述
放大后
在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
ImportError: libopencv_flann.so.3.4: cannot open shared object file: No such file or directory 是一个常见的错误,表示在导入OpenCV库时找不到所需的共享对象文件。这通常是由于缺少OpenCV库或者库文件路径配置不正确引起的。 要解决这个问题,你可以尝试以下几个步骤: 1. 确保你已经正确安装了OpenCV库。你可以通过在终端中运行以下命令来检查OpenCV是否已经安装: ``` pkg-config --modversion opencv ``` 如果没有输出版本号或者提示找不到命令,说明OpenCV没有正确安装。你可以参考OpenCV官方文档或者使用适合你操作系统的包管理器来安装OpenCV。 2. 检查库文件路径配置是否正确。在终端中运行以下命令,查看OpenCV库文件的路径: ``` pkg-config --libs opencv ``` 输出的结果应该包含正确的库文件路径,例如:-L/usr/local/lib -lopencv_core -lopencv_highgui 等。如果路径不正确,你需要更新库文件路径配置。 3. 如果你已经正确安装了OpenCV库,但仍然遇到该错误,可能是因为系统无法找到库文件。你可以尝试将库文件路径添加到LD_LIBRARY_PATH环境变量中。在终端中运行以下命令: ``` export LD_LIBRARY_PATH=/path/to/opencv/lib:$LD_LIBRARY_PATH ``` 将`/path/to/opencv/lib`替换为你的OpenCV库文件所在的路径。 如果以上步骤都没有解决问题,你可以提供更多关于你的操作系统、OpenCV版本和具体错误信息的细节,以便我能够更好地帮助你。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

365JHWZGo

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值