线性回归网络的简洁实现

线性回归网络的简洁实现同样有以下几个部分:

  • 生成数据集
  • 读取数据集
  • 定义模型
  • 初始化模型参数
  • 定义损失函数
  • 定义优化算法
  • 训练

与手动实现不同的部分在于读取数据集、定义模型、初始化模型参数、定义损失函数、定义优化算法、训练。

生成数据集

true_w=torch.tensor([2,-3.4])
true_b=4.2
features,labels=data_generator.synthetic_data(true_w,true_b,1000)

def synthetic_data(w, b, num_examples):
    X = torch.normal(0, 1, (num_examples, len(w)))  # (生成均值为0,方差为1的随机数)
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape(-1, 1)

读取数据集

构造读取数据集函数,并调用框架中现有的API(data.TensorDataset——将数据转换成张量;data.DataLoader——生成小批量样本)来读取数据。

def load_array(data_arrays,batch_size,is_train=True):
    """构造一个pytorch数据迭代器"""
    dataset=data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset,batch_size,shuffle=is_train)

batch_size=10
data_iter=data_iter.load_array((features,labels),batch_size)

定义模型

对于标准深度学习模型,我们可以使用框架的预定义好的层。我们只需要关注使用哪些层来构造模型,不必关注层的实现细节。将模型定义为Sequential类的实例。

在pytorch中全连接层在Linear类中定义。我们将两个参数传入,第一个参数是输入特征的形状,第二个是输出特征的形状。

from torch import nn
net = nn.Sequential(nn.Linear(2,1))

初始化模型参数

我们可以直接访问网络参数以设定它们的初始值。我们通过net[0]选择网络中的第一层,然后使用weight.data和bias.data方法访问参数。使用替换方法normal_和fill_来重写参数值。

net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)

定义损失函数

计算均方误差使用的是MSELoss类,其也称为平方L2范数。默认返回所有样本损失的平均值

loss=nn.MSELoss()

定义优化算法

pytorch在optim模块中实现了随机梯度下降算法的许多变体。当我们实例化一个SGD实例时,我们需要优化的参数可以通过net.parameters()获得,小批量梯度下降算法只需要设置学习率。

trainer = torch.optim.SGD(net.parameters(),lr=0.05)

训练

训练过程:在每轮训练中,遍历所有数据,生成小批量的输入和标签。对每个小批量数据进行下述操作:

  • 通过调用net(X)生成预测并计算损失l
  • 通过反向传播来计算梯度
  • 通过调用优化器来更新参数
num_epochs = 3
for epoch in range(num_epochs):
    for X,y in data_iter:
        l=loss(net(X),y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
    l=loss(net(features),labels)
    print(f'epoch: {epoch+1},loss: {l:f}')

由于损失函数计算的是小批量中样本损失的平均值,所以不用对损失函数求和后再反向传播

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值