Pytorch学习笔记 Task2.2 pytorch实现线性回归

线性回归能够用一个直线较为精准的描述数据之间的关系。当新的数据出现的时候,就能够预测出一个简单的值。

1 pytorch 实现线性回归

from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader
""" 
Pytorch Dataset/TensorDataset和Dataloader 
https://www.jianshu.com/p/3fa75db88387
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

# torch.device代表将torch.Tensor分配到的设备的对象。
# https://ptorch.com/news/187.html#torch-device
device = torch.device('cpu')

# torch.unsqueeze()这个函数主要是对数据维度进行扩充。给指定位置加上维数为一的维度
# https://blog.csdn.net/xiexu911/article/details/80820028
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 0.2 * x + 0.2 * torch.rand(x.size())
x = x.to(device)
y = y.to(device)

BATCH_SIZE = 4
torch_dataset = TensorDataset(x, y)
loader = DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)


# class Net(torch.nn.Module):
#     def __init__(self, n_features, n_hidden, n_output):
#         super(Net, self).__init__()
#         self.hidden = torch.nn.Linear(n_features, n_hidden)
#         self.predict = torch.nn.Linear(n_hidden, n_output)

#     def forward(self, x):
#         x = F.relu(self.hidden(x))
#         x = self.predict(x)
#         return x

# n_input = 1
# n_hidden = 10
# n_output = 1
# net = Net(n_input, n_hidden, n_output).to(device)
# print(net)

class LinearRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        out = self.linear(x)
        return out


net = LinearRegression().to(device)
print(net)

# 开启画图交互模式
plt.ion()

learning_rate = 0.5
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
loss_func = torch.nn.MSELoss()

epochs = 100
for epoch in range(epochs):
    prediction = net(x)
    loss = loss_func(prediction, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 5 == 0:
        # plt.cla()
        fig, ax = plt.subplots()
        ax.scatter(x.data.numpy(), y.data.numpy())
        ax.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        ax.text(0.5, 0, "loss=%.4f" % loss.item(), fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.1)

# 结束画图交互模式
plt.ioff()
plt.show()

2 参考资料

1. 小黑的Python日记
2. pytorch管理数据类型属性
3. torch.squeeze() 和torch.unsqueeze()的用法

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值