Pytorch第四讲
使用PyTorch框架实现简单线性模型
总体步骤
步骤一
构建数据集(Prepare dataset)
步骤2
设计模型(Design model using Class) 注:这里的模型是指用来求解y-hat
步骤3
构造损失函数和优化器(Construct loss and optimizer)
步骤4
编写训练过程(Training cycle: forward, backward, update)
总结
1.求解Y-hat
2.求解loss
3.backward
4.更新权重
实现代码
import torch
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[1.0], [4.0], [6.0]])
class LinearModel(torch.nn.Module):
def __init__(self):
super(LinearModel, self).__init__() #调用父类构造
self.linear = torch.nn.Linear(1, 1) #构造对象 其中第一个参数1为输入(这里指x)的维度,第二个参数1为输出(这里指y)的维度
def forward(self, x):
y_pred = self.Linear(x)
return y_pred
model = LinearModel()
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)
for epoch in range(100):
y_pred = model(x_data)
loss = criterion(y_pred, y_data)
print(epoch, loss)
optimizer.zero_grad()
loss.backward()
optimizer.step()