import torch
import torch.nn as nn
# Toy dataset for the relation y = 2 * x: one sample per row, one feature per column.
X = torch.tensor([[1], [2], [3], [4]], dtype=torch.float32)
Y = torch.tensor([[2], [4], [6], [8]], dtype=torch.float32)

# Single held-out input used to probe the model before and after training.
X_test = torch.tensor([5], dtype=torch.float32)

n_samples, n_features = X.shape
print(n_samples, n_features)

input_size = n_features
# Take the output width from the targets rather than from the inputs: the two
# only coincide here because X and Y both have a single column.
output_size = Y.shape[1]

# NOTE: a bare nn.Linear(input_size, output_size) would also serve as the model;
# the script wraps it in a custom nn.Module instead.
class LinearRegression(nn.Module):
    """Linear regression model built from a single ``nn.Linear`` layer.

    Args:
        input_dim: Number of features in each input sample.
        output_dim: Number of values predicted for each sample.
    """

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        # define layers
        self.lin = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output of the linear layer applied to ``x``."""
        return self.lin(x)
model = LinearRegression(input_size, output_size)

# Reading out a prediction is inference only, so skip building an autograd graph.
with torch.no_grad():
    print(f"Prediction before training: f(5) = {model(X_test).item():.3f}")

# Training
learning_rate = 0.01
n_iters = 10

loss = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for epoch in range(n_iters):
    # prediction = forward pass
    y_pred = model(X)

    # loss -- the criterion is called as (input, target), so the prediction goes first
    loss_value = loss(y_pred, Y)

    # gradients
    loss_value.backward()

    # update weights
    optimizer.step()

    # zero gradients; backward() accumulates into .grad, so reset before the next pass
    optimizer.zero_grad()

    # Report progress on every second epoch.
    if epoch % 2 == 0:
        w, _ = model.parameters()
        print(f"epoch {epoch + 1}, weight: {w[0][0].item():.3f}, loss: {loss_value.item():.8f}")

with torch.no_grad():
    print(f"Prediction after training: f(5): {model(X_test).item():.3f}")
10-03
2267
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)
11-12
1383
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)
03-18
2338
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)
“相关推荐”对你有帮助么?
-
非常没帮助
-
没帮助
-
一般
-
有帮助
-
非常有帮助
提交