基于深度学习框架的线性回归

使用深度学习框架来简介实现线性回归模型,生成数据集

导入相关数据包

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l

生成人工数据集

我们使⽤线性模型参数w = [2, −3.4]⊤、b = 4.2 和噪声项ϵ⽣成数据集及其标签:

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)#生成数据,类型与true_w 和true_b的数据一样

# print(features)
# print(labels)

调⽤框架中现有的API来读取数据

def load_array(data_arrays, batch_size, is_train=True): #@save
    """构造⼀个PyTorch数据迭代器"""
    #将features, labels传入到TensorDataset当中
    dataset = data.TensorDataset(*data_arrays)
    #拿到数据集之后调用DataLoader函数,dataset:读入的数据,batch_size:每次读取数据数,shuffle=is_train:是否进行随机打乱数据顺序
    
    return data.DataLoader(dataset, batch_size, shuffle=is_train)
    #导入数据之后调用DataLoader函数,从数据集当中随机抽取batch_size个数据

batch_size = 10
data_iter = load_array((features, labels), batch_size)

next(iter(data_iter))

使用框架的预定义好的层

# nn是神经⽹络的缩写
from torch import nn

net = nn.Sequential(nn.Linear(2, 1))#定义输入的维度是2,输出的维度是1

    #Sequential是一个有序的容器,神经网络模块按照在传入构造器的顺序依次被添加到计算图中执行
    #同时以神经网络模块为元素的有序字典也可以作为参数进行传入

初始化模型参数

#我们在构造nn.Linear时指定输⼊和输出尺⼨⼀样,现在我们能直接访问参数以设定它们的初始值。我
#们通过net[0]选择⽹络中的第⼀个图层,然后使⽤weight.data和bias.data⽅法访问参数。我们还可
#以使⽤替换⽅法normal_和fill_来重写参数值。
net[0].weight.data.normal_(0, 0.01)#使用正态分布来替换data的值
net[0].bias.data.fill_(0)#设置偏差的值为0

计算均⽅误差使⽤的是MSELoss类,也称为平⽅L2范数

loss = nn.MSELoss()

定义优化算法,实例化SGD

trainer = torch.optim.SGD(net.parameters(), lr=0.03)#parameters当中包括所有的参数,lr是学习率

训练模型

num_epochs = 3#设置迭代周期
for epoch in range(num_epochs):
    for x, y in data_iter:
        l = loss(net(x) ,y)#拿到预测的值与真实的值y做loss
        trainer.zero_grad()#梯度清0
        l.backward()
        trainer.step()#调用step函数进行模型的更新
    l = loss(net(features), labels)#将所有的features放入net当中与labels做loss
    print(f'epoch {epoch + 1}, loss {l:f}')
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值