线性回归(MXNet)

(博客为《动手学深度学习》P35-38的学习笔记)

一、生成数据集

特征数为2,训练数据集大小为1000。

from mxnet import autograd, nd
nums_input = 2
nums_example = 1000
true_w = [2, -3.4]
true_b = 4.2
features = nd.random.normal(scale=1, shape=(nums_example, nums_input))
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
labels += nd.random.normal(scale=0.01, shape=labels.shape)  # 给标签添加噪声

二、读取数据集

通过MXNet中Gluon提供的data包,读取数据集。

from mxnet.gluon import data as gdata

将训练数据集的特征和标签组成训练数据集的实例。

dataset = gdata.ArrayDataset(features, labels)

batch_size为10,将训练数据集分成多个mini-batch。

batch_size = 10
data_iter = gdata.DataLoader(dataset, batch_size, shuffle=True)

三、 定义模型

通过MXNet中Gluon提供的nn包,定义模型。

from mxnet.gluon import nn

创建神经网络容器实例,可以串联各个层次。

net = nn.Sequential()

在容器中添加一个全连接层,且该层输出个数为1。

net.add(nn.Dense(1))

四、初始化模型参数

通过MXNet中init包,生成初始的模型参数。

from mxnet import init

初始化模型参数。

net.initialize(init.Normal(sigma=0.01))

五、定义损失函数

通过MXNet中Gluon的loss包,定义损失函数。

from mxnet.gluon import loss as gloss

创建损失函数,该模型使用L2范数损失。

loss = gloss.L2Loss()

六、定义优化算法

通过MXNet中Gluon包,定义优化器。

from mxnet import gluon

 设置优化器的参数,‘sgd’为随机梯度下降,学习率为0.03。其中,collect_params()用于收集模型参数。

trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.03})

七、训练模型

迭代次数为3。

nums_epochs = 3
for epoch in range(nums_epochs):
    for X, y in data_iter:
        with autograd.record():
            l = loss(net(X), y)
        l.backward()
        trainer.step(batch_size)  # 更新训练后的参数
    l = loss(net(features), labels)
    print("epoch {}, loss {}".format(epoch + 1, l.mean()))

# 输出训练后的参数
dense = net[0]
print(dense.weight.data(), dense.bias.data())

输出结果为:

epoch 1, loss
[0.03510439]
<NDArray 1 @cpu(0)>
epoch 2, loss
[0.00012555]
<NDArray 1 @cpu(0)>
epoch 3, loss
[4.845141e-05]
<NDArray 1 @cpu(0)>

[[ 1.9997343 -3.4001033]]
<NDArray 1x2 @cpu(0)>
[4.1997924]
<NDArray 1 @cpu(0)>

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

学习啊ZzZ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值