pytorch实现简单线性模型
一、数据集的预处理
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])
使用 torch.tensor 函数对数据进行预处理。
二、线性模型的构建
class LinearModel(torch.nn.Module):
def __init__(self):
super(LinearModel, self).__init__()
self.
原创
2020-08-14 14:47:32 ·
1128 阅读 ·
0 评论