用PyTorch搭建一个线性模型

本文介绍了如何使用Python的PyTorch库,通过读取CSV数据,构建线性模型,构造优化器并利用梯度下降法进行训练,最小化均方误差,最终评估模型性能的过程。
摘要由CSDN通过智能技术生成

目标

根据100个样本数据找出合适的𝑤和𝑏,使得 𝑦=𝑤𝑥+𝑏。

步骤

1.读取数据
2.构造一个线性模型
3.构造优化器
4.最小化方差(训练)
5.性能评估

实施过程

import pandas as pd
import torch
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"     # 解决中文乱码问题

# 1-读取数据
data = pd.read_csv('line_fit_data.csv').values         # 将数据读取后存储为ndarray格式
X = torch.tensor(data[:, 0], dtype=torch.float32)      # 样本自变量,并将数据类型转换成浮点型float
y = torch.tensor(data[:, 1], dtype=torch.float32)      # 目标变量

# 初始权值和阈值
W = torch.tensor(-8.0, requires_grad=True)     # 初始权值为-10.0,并且设置需要修正
b = torch.tensor(5.0, requires_grad=True)  # 初始阈值
learning_rate = 0.25                              # 学习速率

# 2-构建模型及损失函数
def linear_model(W, X, b):  # 构建的线性模型为:y=WX+b
    return W*X+b

def loss_fn( y_pre, y_true):             # 构造损失函数
    return ((y_true - y_pre)**2).mean()  # 均方误差

plt.figure(figsize=(10, 5))                                # 设置画布大小
plt.axis([-0.01, 2, -2, 8])                               # 指定坐标轴的取值范围
plt.scatter(X, y, color='blue')                            # 绘制样本实际分布图(散点图)
plt.plot(X, linear_model(W, X, b).data, color='red') # 绘制模型预测结果分布
plt.legend(['target_y', 'predicted'])                      # 设置图例

for Round in range(200): # 多轮优化
    plt.plot(X, linear_model(W, X, b).data, color='red', alpha=0.2)  # 动态绘制模型预测结果分布
    plt.pause(0.4)  # 设置停留时长

    # 3-构建优化器
    y_pre = linear_model(W,X,b)    # 前向传播
    loss = loss_fn(y_pre, y)       # 模型损失值
    loss.backward()                # 误差反向传播(计算梯度),以此改变W和b

    # 4-最小化方差,目的是更新权值&阈值
    W.data = W.data - W.grad * learning_rate   # 沿着梯度的反方向更新权值
    b.data = b.data - b.grad * learning_rate   # 沿着梯度的反方向更新阈值

    W.grad.zero_()        # 将权值的梯度清零
    b.grad.zero_()        # 将阈值的梯度清零

    print('Round:', Round, ' Loss:', loss.item(), 'W:', W.item(), ' b:', b.item())

# 5.性能评估
plt.figure(figsize=(10, 5))       # 设置画布大小
plt.axis([-0.01, 2, -2, 8])      # 指定坐标轴的取值范围
plt.scatter(X, y, color='blue')   # 绘制样本实际分布图
plt.plot(X, linear_model(W, X, b).data, color='red')     # 绘制模型预测结果分布
plt.legend(['target_y', 'predicted'])                    # 设置图例
plt.show()
  • 6
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值