一,引入库和使用工具
import torch.nn as nn
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import datetime
版本
python 3.7
torch == 1.9.0
matplotlib == 3.3.4
二、利用pytorch生成一个随机张量?
# 选取200个x轴坐标,-1到1直接的等差数列
x = torch.unsqueeze(torch.linspace(-1, 1, 200), dim=1)
# 随机选取200个y轴坐标
y = 5 * x + 0.8 * torch.rand(x.size())
# 将x,y添加到Variable
X = Variable(x)
Y = Variable(y)
三、创建模型
代码如下:
# 迭代次数
epoch = 1000
# 学习率
learning_rute = 0.0001
# 定义模型
model = nn.Linear(1, 1)
# 定义losss损失
square_loss = nn.MSELoss(reduction='sum')
# 创建优化器
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rute)
四.开始运行
# 让matplotlib进入交互
plt.ion()
for i in range(epoch):
# 计算预测值
y_hat = model(X)
# 计算损失
loss = square_loss(y_hat, Y)
# 每迭代100次,打印一次损失
if (i + 1) % 100 == 0:
print(loss)
# 梯度置零
optimizer.zero_grad()
# 反向传播
loss.backward()
# 参数更新
optimizer.step()
# matplotlib窗口清空
plt.cla()
# 绘制x,y坐标点
plt.scatter(X.data.numpy(), Y.data.numpy())
# 绘制直线
plt.plot(X.data.numpy(), y_hat.data.numpy(), 'r-', lw=2)
# 暂停0.05s,用于观察
plt.pause(0.05)
# 关闭交互
plt.ioff()
五.运行结果
tensor(106.4726, grad_fn=<MseLossBackward>)
tensor(16.7484, grad_fn=<MseLossBackward>)
tensor(10.7933, grad_fn=<MseLossBackward>)
tensor(10.3978, grad_fn=<MseLossBackward>)
tensor(10.3715, grad_fn=<MseLossBackward>)
tensor(10.3698, grad_fn=<MseLossBackward>)
tensor(10.3696, grad_fn=<MseLossBackward>)
tensor(10.3696, grad_fn=<MseLossBackward>)
tensor(10.3696, grad_fn=<MseLossBackward>)
tensor(10.3696, grad_fn=<MseLossBackward>)
六.全部代码
import torch.nn as nn
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import datetime
t_start = datetime.datetime.now()
x = torch.unsqueeze(torch.linspace(-1, 1, 200), dim=1)
print(x.size())
y = 5 * x + 0.8 * torch.rand(x.size())
X = Variable(x)
Y = Variable(y)
epoch = 1000
learning_rute = 0.0001
model = nn.Linear(1, 1)
square_loss = nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rute)
plt.ion()
for i in range(epoch):
y_hat = model(X)
loss = square_loss(y_hat, Y)
if (i + 1) % 100 == 0:
print(loss)
optimizer.zero_grad()
loss.backward()
optimizer.step()
plt.cla()
plt.scatter(X.data.numpy(), Y.data.numpy())
plt.plot(X.data.numpy(), y_hat.data.numpy(), 'r-', lw=2)
plt.pause(0.05)
plt.ioff()
print('\n',datetime.datetime.now() - t_start)
白嫖不好,创作不易。各位的点赞就是我创作的最大动力,如果我有哪里写的不对,欢迎评论区留言进行指正。各位老铁要是觉得这篇文章很有用的话,麻烦手抖点个赞,大家的支持是我前进的动力,后期将发布更好的作品,文章有哪些错误,也麻烦各位大佬纠正,谢谢各位!