刘二大人《PyTorch深度学习实践》pytorch实现线性回归

05.用PyTorch实现线性回归_哔哩哔哩_bilibili

使用pytorch框架后虽然比上次写的复杂一些,但是可以方便 扩展成更复杂的神经网

用pytorch提供的工具实现线性模型训练过程

1.准备数据

2.设计模型(计算\hat{y}),继承nn.Module

3.设置loss函数和优化器,

4.开始训练(forward,backward,update)

Pytorch不再人工求导数,主要构造计算图             

确定输入x和输出\hat{y}的维度,这样就知道w和b的维度

\hat{y}维度可以很大,但是loss必须是标量 

把模型定义成一个类,以后不管什么模型都是这是形式:

class Model(torch.nn.Module)

至少有两个方法_init_和forward

torch.nn.Linear :自动包含w和b         

 \begin{bmatrix}y_{pred}^{(1)} \\ y_{pred}^{(2)} \\ y_{pred}^{(3)} \end{bmatrix}=w\cdot \begin{bmatrix}x^{(1)} \\ x^{(2)} \\ x^{(3)} \end{bmatrix}+b

def _call_(self,*args,**kwargs); #用于参数维度不定,很灵活

*args 可变参数传递

**kwargs 参数变成词典

如torch.optim.SGD 自动帮你找到要更新的权重

代码:

import matplotlib.pyplot as plt
import torch

x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[2.0],[4.0],[6.0]])

#用pytorch框架随机梯度下降拟合y=wx

class LinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear=torch.nn.Linear(1,1)
        # (1,1)是指输入x和输出y的特征维度,这里数据集中的x和y的特征都是1维的
        # 该线性层需要学习的参数是w和b  获取w/b的方式分别是~linear.weight/linear.bias

    def forward(self,x):
        y_pred=self.linear(x)  #_call()_ 可调用的对象
        return y_pred

model=LinearModel()

criterion = torch.nn.MSELoss(size_average=True) #求loss
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)  #学习率为0.01 随机梯度下降


epoch_list=[]
cost_list=[]

x_test=torch.Tensor([[4.0]])
print('predict(before training)','x=4 y=',model(x_test).item())  ###训练前对x=4.0的y值预测

for epoch in range(100):
        y_pred=model(x_data)
        loss_val=criterion(y_pred,y_data)  #求MSEloss
        print('epoch:', epoch,  'loss=', loss_val.item())

        loss_val.backward()                #反向传播
        optimizer.step()        #更新权值
        optimizer.zero_grad()   #清零

        epoch_list.append(epoch)
        cost_list.append(loss_val.item())  # loss

print('w=',model.linear.weight.item())  #w
print('b=',model.linear.bias.item())    #b


print('predict(after training)','x=4 y=',model(x_test).item())  ###训练后对x=4.0的y值预测

plt.plot(epoch_list,cost_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()


if __name__ == "__main__":
    main()

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值