简单的一维线性回归-pytorch

代码来自书籍《深度学习入门之pytorch》,书里有几个错误,修改好了


import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable

import numpy as np

x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],
                   [10.088],[5.33],[8.5],[4.2]],dtype=np.float32)
y_train = np.array([[1.7],[2.8],[2.3],[4.0],[6.3],[5.3],[6.2],[8.6],[9.9],[2.8],[6.3],
                    [11.8],[8.5],[6.5],[3.8]],dtype=np.float32)
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)

#define a LinearRegression Class
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.linear = nn.Linear(1,1)

    def forward(self, x):
        out = self.linear(x)
        return out


#create a model
if torch.cuda.is_available():
    print('GPU1')
    model = LinearRegression().cuda()
else:
    print('CPU1')
    model = LinearRegression()

#define a optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(),lr=1e-3)

#start to train model
num_epoch = 25000
for epoch in range(num_epoch):
    if torch.cuda.is_available():
        #print('GPU2')
        inputs = Variable(x_train).cuda()
        target = Variable(y_train).cuda()
    else:
        #print('CPU2')
        inputs = Variable(x_train)
        target = Variable(y_train)

    #forward
    out = model(inputs)
    loss = criterion(out,target)

    #backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1)%20 ==0:
        print('Epoch[{}/{}],loss:{:.6f}'.format(epoch+1,num_epoch,loss.data[0]))

model.eval()
if torch.cuda.is_available():
    print('GPU3')
    predict = model(Variable(x_train).cuda())
    predict = predict.data.cpu().numpy()
else:
    print('CPU3')
    predict = model(Variable(x_train))
    predict = predict.data.numpy()

plt.plot(x_train.numpy(),y_train.numpy(),'ro',label='Original data')
plt.plot(x_train.numpy(),predict,label='Predict Line')
plt.show()

运行结果:

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值