声明:本文参考资料为《动手学深度学习》
1、整体思路
使用Pytorch框架,完成对线性回归的实现
2、详细代码分析
本文与https://blog.csdn.net/weixin_58161464/article/details/116707519进行对比
import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)
对于数据集的生成,两者无差别
def load_array(data_arrays, batch_size, is_train=True):
"""构造一个PyTorch数据迭代器。"""
dataset = data.TensorDataset(*data_arrays) #类似zip,将features和labels组合在一起
return data.DataLoader(dataset, batch_size, shuffle=is_train)
batch_size = 10
data_iter = load_array((features, labels), batch_size)