第6讲 逻辑斯蒂回归 logistic regression
pytorch学习视频——《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili
以下是视频内容笔记以及视频中的手敲源码,可能会有些许改动,笔记纯属个人理解,如有错误勿介或者欢迎路过的大佬指出 。
1. torchversion数据集
# mnist数据集,手写数字识别
mnist_train_set = torchvision.datasets.MNIST(root='./dataset/mnist', train=True, download=False)
mnist_test_set = torchvision.datasets.MNIST(root='./dataset/mnist', train=False, download=False)
# CIFAR-10数据集——飞机、卡车、猫、鸟等
cifar10_train_set = torchvision.datasets.CIFAR10(root='./dataset/cifar10', train=True, download=True)
cifar10_test_set = torchvision.datasets.CIFAR10(root='./dataset/cifar10', train=False, download=True)
2. logistic 函数
(也可以叫sigmoid函数,logistic函数是最典型的sigmoid函数)
-
将变量映射到0-1之间的数
问题描述:根据输入的学习时间,来预测考试是否合格,最终输出是考试通过的概率。(是一个二分类问题)
概率值用logistic函数值来表示,函数中输入的x是原来的y_pred。
logistic函数图像如图所示,饱和函数—— 导数图像服从正态分布。
-
sigmoid函数
-
特点(充分条件)
-
饱和函数
-
单调增
-
有极限
-
-
其它sigmoid函数
-
3. logistic函数使用的一些小细节
-
模型的改变
-
使用了sigmoid函数之后与原来模型的差异,如图所示
-
-
损失函数的改变
-
原来使用的MSE函数,直接比较输出和真实值数值的差异。
-
现在的输出是概率值,需要比较分布的差异
-
KL散度
-
cross-entropy交叉熵
交叉熵公式如图所示(BCE损失):
-
-
-
举个栗子
使用mini-batch损失
4. logistic回归的实现
logistic_regression.py
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [1], [1]])
class LogisticRegressionModel(torch.nn.Module):
# 初始化函数和线性回归模型的一样,因为logistic函数中没有新的参数需要初始化
def __init__(self):
super(LogisticRegressionModel, self).__init__()
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
y_pred = F.sigmoid(self.linear(x))
return y_pred
model = LogisticRegressionModel()
criterion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(1000):
y_pred = model(x_data)
loss = criterion(y_pred, y_data)
if epoch % 50 == 0:
print(epoch, loss.item())
# 将梯度值清0
optimizer.zero_grad()
# 反向传播求梯度
loss.backward()
# 优化器 更新权重
optimizer.step()
# 测试
x = np.linspace(0, 10, 200) # 返回0-10等间距的200个数
x_t = torch.Tensor(x).view((200, 1)) # reshape成一个200行1列的矩阵tensor
y_t = model(x_t) # 传入模型进行测试
# 调用numpy将y_t变成n维数组
y = y_t.data.numpy()
# 图1
plt.plot(x, y)
# 图2——这是y=0.5那条横线
plt.plot([0, 10], [0.5, 0.5], c='r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()
结果:
(代码理解请参考注释,以上代码已经在pycharm上经过测试,笔记纯属个人理解,如有错误勿介或者欢迎路过的大佬指出 嘻嘻嘻。)
——未完待续……