Pytorch学习之:搭建一个回归网络并训练、测试

import torch
from torch import nn
import numpy as np

class LinearModel(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        # 定义线性层
        self.linear_layer = nn.Linear(input_dim, output_dim)

    # 定义前向传播的方式
    def forward(self,x):
        output = self.linear_layer(x)
        return output




def train(model,input_data, labels):
    learning_rate = 0.01
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    # 把数据搬到 GPU 上(如果没有 GPU 就放在 cpu 上)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # 开始进行训练过程
    for epoch in range(5000):

        x_train = torch.from_numpy(input_data).to(device)
        y_train = torch.from_numpy(labels).to(device)

        # 梯度要清零
        optimizer.zero_grad()
        outputs = model(x_train)

        # 计算损失
        loss = criterion(outputs, y_train)
        # 反向传播
        loss.backward()
        # 权重更新
        optimizer.step()

        if epoch % 500 == 0:
            print("epoch{}, loss:{} ".format(epoch, loss.item()))


def test(model):
    lst = [6,7,8,10]
    ts = torch.from_numpy(np.array(lst,dtype=np.float32).reshape(-1,1))
    x = model(ts)
    print(x.data.numpy())


if __name__ == '__main__':
    # 创建训练数据,注意维度的组织顺序
    input_data = np.array([i for i in range(10)],dtype=np.float32).reshape(-1,1)
    print(input_data.shape)
    # 创建训练的标签
    labels = np.array([2*i + 5 for i in range(10)],dtype=np.float32).reshape(-1,1)
    print(labels.shape)
    # 实例化网络
    model = LinearModel(1,1)
    # 对网络进行训练
    train(model,input_data,labels)
    # 测试网络的拟合效果
    test(model)


在这里插入图片描述

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

暖仔会飞

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值