PyTorch 入门与实践(一)逻辑斯蒂回归

来自 B 站刘二大人的《PyTorch深度学习实践》P6 的学习笔记

Logistic Function(Sigmoid):

σ ( x ) = 1 1 + e − x , 可 知 σ ( x ) ∈ ( 0 , 1 ) , 用 于 概 率 预 测 。 \sigma(x) = \frac{1}{1+ e^{-x}},可知 \sigma(x) \in (0, 1),用于概率预测。 σ(x)=1+ex1σ(x)(0,1)
1
各种 Sigmoid 函数:
2
只需在线性回归模型最后加入 Sigmoid 输出层即可做逻辑斯蒂回归预测:
3

二分类问题的损失函数

线性回归问题可以用距离来衡量预测结果;

但是分类问题需要使用 KL 散度或交叉熵做概率预测。
4
5
二分类问题的交叉熵损失:
Mini-Batch Loss Function for Binary Classification(BCE Loss):
l o s s = − 1 N ∑ n = 1 N y n log ⁡ y ^ n + ( 1 − y n ) log ⁡ ( 1 − y ^ n ) loss = -\frac{1}{N} \sum^N_{n=1} y_n \log \hat y_n + (1 - y_n) \log (1 - \hat y_n) loss=N1n=1Nynlogy^n+(1yn)log(1y^n)
6

Implementation of Logistic Regression

以下是使用 PyTroch 建模和训练神经网络的四个一般步骤:
7
模拟的数据是学习时间和考试通过率,得到的输出的分布是拟合 Sigmoid 函数的;

可以观察到大于 3 小时的学习时间考试通过率才会大于 0.5,这与训练数据 x ≥ 3.0 , y > 0 x \ge 3.0,y \gt 0 x3.0y>0 是吻合的。
8
经本人测试的完整代码:

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
import matplotlib.pyplot as plt
from torchvision import datasets

# train_set = datasets.MNIST(root="../datasets/mnist", train=True, download=True)
# test_set = datasets.MNIST(root="../datasets/mnist", train=False, download=True)

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])


class LogisticRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        y_pred = F.sigmoid(self.linear(x))
        return y_pred


model = LogisticRegressionModel()

criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

#loss_lst = []  # 保存loss值以便可视化

for epoch in range(10000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print("epoch:", epoch, "loss:", loss.item())
    #loss_lst.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# np.savetxt("./loss_lst.txt", loss_lst)

x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
plt.hlines(0.5, x[0], x[-1], c="r")
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Skr.B

WUHOOO~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值