3.3 线性回归的简洁实现
随着深度学习框架的发展,开发深度学习应用变得越来越便利。实践中,我们通常可以用比上一节更简洁的代码来实现同样的模型。在本节中,我们将介绍如何使用PyTorch更方便地实现线性回归的训练。
3.3.1 生成数据集
我们生成与上一节中相同的数据集。其中features
是训练数据特征,labels
是标签。
num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2
features = torch.tensor(np.random.normal(0, 1, (num_examples, num_inputs)), dtype=torch.float)
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)
Copy to clipboardErrorCopied
3.3.2 读取数据
PyTorch提供了data
包来读取数据。由于data
常用作变量名,我们将导入的data
模块用Data
代替。在每一次迭代中,我们将随机读取包含10个数据样本的小批量。
import torch.utils.data as Data
batch_size = 10
# 将训练数据的特征和标签组合
dataset = Data.TensorDataset(features, labels)
# 随机读取小批量
data_iter = Data.DataLoader(dataset, batch_size