pytorch实现简单线性回归模型
目录:完整注释在代码中思路生成原始数据生成小批量数据定义模型(网络)定义误差函数定义优化算法主函数传参训练完整代码
思路
生成原始数据
根据原始数据以及batch_size,生成小批量数据
定义模型(网络)
定义误差函数
定义优化算法
主函数传参训练
生成原始数据
def synthetic_data(w, b, num_examples):
""" 生成 y = Xw + b + 噪声。"""
# torch.normal(mean, std, (x, y)):返回一个shape为(x
原创
2022-01-18 18:37:04 ·
1184 阅读 ·
0 评论