【机器学习】007_线性回归模型Part.5_线性回归模型简洁实现

通过深度学习框架来简洁地实现线性回归模型

1. 初始化

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l

# 初始化w, b的真实值,定义特征值与标签数
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

2. 调用框架中现有的API来读取数据

获取数据集,并获取小批量数据集的数据迭代器

def load_array(data_arrays, batch_size, is_train=True):
    # 构造一个pytorch数据迭代器
    # 已知features和labels,可以直接调用TensorDataset函数来将特征数据与标签数据直接生成数据集
    dataset = data.TensorDataset(*data_arrays)
    # 拿到数据集之后,可以调用DataLoader每次从中拿出小批量数据,并打乱标签
    # 最后返回一个数据迭代器对象
    return data.DataLoader(dataset, batch_size, shuffle=True)

batch_size = 10
# 调用load_array函数将特征数据与标签数据传入,得到一个数据迭代器对象
data_iter = load_array((features, labels), batch_size)
# 使用next可以获取数据迭代器的下一个批量数据
next(iter(data_iter))

3. 线性回归模型

# 使用框架预定义的线性层
# nn是神经网络的缩写
from torch import nn
# 指定输入维度和输出维度
net = nn.Sequential(nn.Linear(2, 1))
# 初始化模型参数
# net[0]访问Linear,weight访问w,data访问w的值,normal_可以将w设置均值为0,方差0.01正态分布
net[0].weight.data.normal_(0, 0.01)
# 同样地将方差设置为0
net[0].bias.data.fill_(0)

4. 直接调用平方损失和梯度算法

计算均方误差使用的是 MSELoss 类,也称为平方范数

实例化 SGD 实例

# 直接调用平方损失和梯度算法
loss = nn.MSELoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.03)

5. 训练过程

# 训练模型
num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        # 首先计算模型的预测值net(X),并将其与真实标签y传入损失函数loss中计算损失l
        # net里面自带了模型参数,因此不再需要将w和b的值传入了
        l = loss(net(X), y)
        # 调用trainer.zero_grad()将模型参数的梯度清零,以便进行下一次迭代时重新计算梯度
        trainer.zero_grad()
        # 调用l.backward()进行反向传播,计算损失对模型参数的梯度
        l.backward()
        # 调用trainer.step()根据梯度更新模型参数
        trainer.step()
    # 使用loss(net(features), labels)计算整个数据集上的损失l
    l = loss(net(features), labels)
    print(f'epoch {epoch+1}, loss {l:f}')

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值