torch 实现线性回归

torch 实现线性回归的写法

重写 Dataset 类以及定义线性回归网络

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

random_seed = 123
torch.manual_seed(random_seed)

class linear_regression(nn.Module):
    """
    模型类: 线性回归
    """
    def __init__(self, input_dim):
        super(linear_regression, self).__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, X):                       # X:(bzs, input_dim)
        output = self.linear(X)
        return output

class myDataset(Dataset):
    """
    torch 数据集类 重写
    """
    def __init__(self, X, y):
        self.data = X
        self.label = y

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)


def data_prepare(X, w, b):
    """
    生成线性分布,并添加高斯噪声
    """
    y = torch.matmul(X, w) + b
    # y += torch.rand_like(y) * 0.001           # (0,0.001]的均匀分布
    y += torch.normal(0, 0.01, size=y.size())   # 均值为0,方差为0.01的正态分布
    return y.reshape(-1, 1)                      # 和 net(X) 的结果对齐       1*bsz -> bsz*1

num_data = 100
input_dim = 3
X = torch.rand(num_data, input_dim)             # (0,1] 均匀分布
# X = torch.randn(num_data, input_dim)          # 均值为0,方差为1的正态分布

true_w, true_b = torch.tensor([2.0, 4.4, 5.0]), -1
y = data_prepare(X, true_w, true_b)

data = myDataset(X, y)
# print(data[0])

bsz = 10
dataloader = DataLoader(data, batch_size=bsz)

net = linear_regression(input_dim)
loss = nn.MSELoss()
lr = 0.03
optim = torch.optim.SGD(net.parameters(), lr=lr)
# optim = torch.optim.Adam(net.parameters(), betas=(0.9, 0.999), lr=lr)
num_epoch = 100
for epoch in range(num_epoch):
    print("epoch:", epoch)
    for X, y in dataloader:
        l = loss(net(X), y)
        optim.zero_grad()
        l.backward()
        optim.step()
    print("loss:", l)

print(net.state_dict())

参考:动手学深度学习

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值