1. 线性回归/非线性回归代码

import numpy as np
import matplotlib.pyplot as plt
from torch import nn, optim
from torch.autograd import Variable
import torch

# 线性回归,造训练数据 y = 0.1x + 0.2 + noise
x_data = np.random.rand(100) # ndarray{Size: 100)
noise = np.random.normal(0, 0.01, x_data.shape)
y_data = x_data * 0.1 + 0.2 + noise
x_data = x_data.reshape(-1, 1) # ndarray{Size: (100,1)) 变成vector
y_data = y_data.reshape(-1, 1)

# 非线性回归, y = x^2 + noise
x_data = np.linspace(-2,2,200)[:,np.newaxis] # ndarray{Size: (200,1)) 另一种变成vector的方式
noise = np.random.normal(0,0.2,x_data.shape)
y_data = np.square(x_data) + noise

plt.scatter(x_data, y_data)
plt.show()

# 把numpy数据变成tensor
x_data = torch.FloatTensor(x_data)
y_data = torch.FloatTensor(y_data)
inputs = Variable(x_data)
target = Variable(y_data)

# 线性回归
# 构建神经网络模型
# 一般把网络中具有可学习参数的层放在__init__()中
class LinearRegression(nn.Module):
    # 定义网络结构
    def __init__(self):
        # 初始化nn.Module
        super(LinearRegression, self).__init__()
        self.fc = nn.Linear(1, 1)

    # 定义网络计算
    def forward(self, x):
        out = self.fc(x)
        return out

# 非线性回归
class NonLinearRegression(nn.Module):
    # 定义网络结构
    def __init__(self):
        # 初始化nn.Module
        super(NonLinearRegression, self).__init__()
        # units:1-10-1
        self.fc1 = nn.Linear(1,10)
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(10,1)
        
    # 定义网络计算
    def forward(self,x):
        x = self.fc1(x)
        x = self.tanh(x)
        x = self.fc2(x)
        return x


# 定义模型
model = LinearRegression()
# 定义代价函数
mse_loss = nn.MSELoss()
# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.1) #非线性回归学习率定大点,lr=0.3


for i in range(1001): #非线性回归迭代次数定大点,range(2001)
    out = model(inputs)
    # 计算loss
    loss = mse_loss(out, target)
    # 梯度清0
    optimizer.zero_grad()
    # 计算梯度
    loss.backward()
    # 修改权值
    optimizer.step()
    if i % 200 == 0:
        print(i, loss.item())

y_pred = model(inputs)
plt.scatter(x_data, y_data)
plt.plot(x_data.data.numpy(), y_pred.data.numpy(), 'r-', lw=3)
plt.show()

reference:
Qinbf github pytorch入门课程

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值