《PyTorch 深度学习实践》第6讲 逻辑斯谛回归

下载数据集操作

import torchvision

# 手写数字数据集。下载到data文件夹下
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True)
test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True)
# CIFAR-10数据集
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)

逻辑斯谛回归示例:

import torch

# 1、准备数据集(二分类)
x_data = torch.tensor([[1.0], [1.5], [2.0], [3.0]])
y_data = torch.tensor([[0.0], [0.0], [0.0], [1.0]])


# 2、用类设计模型
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)  # 输入维度为1,即特征数。输出维度为1
        # nn.Linear类包含两个成员张量:weight和bias

    def forward(self, x):  # 重写父类forward函数,不能少,每次传参时会自动调用该函数
        y_pre = torch.sigmoid(self.linear(x))  # 传入x计算预测值y_pre
        return y_pre


model = LogisticRegressionModel()  # 实例化模型

# 3、构造损失和优化器
criterion = torch.nn.BCELoss(reduction='sum')  # 以二分类交叉熵作为损失衡量
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 4、训练
for epoch in range(10000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    # print(epoch, loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print('w=', model.linear.weight.item())
print('b=', model.linear.bias.item())

x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_pred:', y_test.data.item())
w= 4.402466297149658
b= -10.868648529052734
y_pred: 0.9988201260566711

如果进行10000轮更新,得出的参数如上。用此模型预测:当学习时间为4h时,考试通过的概率为0.9988。

我们根据该模型,画出学习时间与考试通过概率之间的关系图像:

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 10, 200)
# 下面这行记得加 dtype=torch.float32,否则报错
x_t = torch.tensor(x, dtype=torch.float32).view((200, 1))  # view用法跟reshape类似,转成(200,1)
y_t = model(x_t)  # 根据训练出的模型得出不同学习时间下的考试通过概率
y = y_t.data.numpy()  # 为了绘图,转成numpy类型
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass Examination')
plt.grid()
plt.show()

在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值