05-用PyTorch实现线性回归

使用PyTorch解决问题的四步:

  1. 准备数据集
  2. 设计模型,计算y_pred
  3. 构造Loss、optimizer(损失函数和优化器)
  4. Training cycle(前馈 -> 反馈 -> 更新)

forward:算出损失;backward:求出梯度

一、准备数据集:

需要的数据都得是Tensor(张量)

二、设计模型 构造计算图

所有模型都要继承自Model
最少实现两个成员方法
    构造函数 初始化:__init__()
    前馈:forward()

  1. 先构造计算图
  2. 通过输入x输出y_pred的维度,确定计算图中wb的维度
  3. 计算出loss,然后进行backward

最终的Loss是一个0维的张量,可能需要对向量Loss取平均或求和。

Linear()方法:

torch.nn.Linear(in_features, out_features, bias = True)

in_features:输入x的维度

out_features:输出y的维度

bias:是否有偏置b

三、构造Loss和Optimizer

四、训练迭代

  1. Forward()
  2. 梯度清零
  3. 反馈
  4. 更新优化器
  5. import torch
    
    # mini-batch需要的数据是Tensor
    x_data = torch.Tensor([[1.0], [2.0], [3.0]])
    y_data = torch.Tensor([[2.0], [4.0], [6.0]])
    
    # Design Model  重点目标在于构造计算图
    """
    所有模型都要继承自Model
    最少实现两个成员方法
        构造函数 初始化:__init__()
        前馈:forward()
    Model自动实现backward
    可以在Functions中构建自己的计算块
    """
    class LinearModel(torch.nn.Module):
        def __init__(self):
            super(LinearModel, self).__init__()
            self.linear = torch.nn.Linear(1, 1)      # 构造了一个包含 w和 b的对象
    
        def forward(self, x):
            y_pred = self.linear(x)                 # linear成为了可调用的对象 直接计算forward
            return y_pred
    
    model = LinearModel()               # 创建类的实例
    
    # 3.Construct Loss(MSE (y_pred - y)**2 ) and Optimizer
    # 构造计算图就需要集成Model模块
    criterion = torch.nn.MSELoss(size_average=False)    #需要的参数是y_pred和y
    optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)
    
    # 4.Training Cycle
    for epoch in range(100):
        y_pred = model(x_data)          # Forward
        loss = criterion(y_pred, y_data)
        print(epoch, loss.item())
    
        optimizer.zero_grad()           # 梯度清零
        loss.backward()                 # 反馈
        optimizer.step()                # 更新
    
    #Output weight and bias
    print('w = ', model.linear.weight.item())
    print('b = ', model.linear.bias.item())
    
    #Test Model
    x_test = torch.Tensor([[4.0]])
    y_test = model(x_test)
    print('y_pred = ', y_test.data)
    

    结果:

  6. 小知识点:

  7. 如果要使用一个可调用对象,那么在类的声明的时候要定义一个call()函数就OK了,就像这样

    class Foobar:
    	def __init__(self):
    		pass
    	def __call__(self,*args,**kwargs):
    		pass
    

    其中参数*args代表把前面n个参数变成n元组,**kwargsd会把参数变成一个词典,举个例子:

     def func(*args,**kwargs):
     	print(args)
     	print(kwargs)
    
    #调用一下
    func(1,2,3,4,x=3,y=5)
    

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值