前言:
这是来自于我的python深度学习的第五章的知识点总结,部分知识点是基于前几篇博客而增加的,希望我的这篇博客可以帮到大家。
注:有许多内容来自于课件PPT及网络摘抄。
学习目标:
1.线性回归简介
2.logistic回归简介
3.用PyTorch实现 Logistic回归
学习内容:
一、线性回归简介
线性回归是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法,运用十分广泛。其表达形式为y = w'x+e,e为误差服从均值为0的正态分布。 回归分析中,只包括一个自变量和一个因变量,且二者的关系可用一条直线近似表示,这种回归分析称为一元线性回归分析。如果回归分析中包括两个或两个以上的自变量,且因变量和自变量之间是线性关系,则称为多元线性回归分析。
最小二乘法
最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小。最小二乘法还可用于曲线拟合,其他一些优化问题也可通过最小化能量或最大化熵用最小二乘法来表达 。最小二乘法因其原理简单、收敛速度较快、易于理解和实现而被广泛应用于参数估计中。
二、logistic回归简介
logistic回归又称logistic回归分析,是一种广义的线性回归分析模型,常用于数据挖掘,疾病自动诊断,经济预测等领域。 逻辑回归根据给定的自变量数据集来估计事件的发生概率,由于结果是一个概率,因此因变量的范围在 0 和 1 之间。
(1)回归vs分类
三、用PyTorch实现 Logistic回归
1、数据准备; 2、线性方程; 3、激活函数; 4、损失函数; 5、优化算法; 6、模型可视化
我们这里用一个练习题更好的理解如何用PyTorch实现 Logistic回归:
以下是此练习的代码及注释。
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
# 1.准备数据集
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])
class LogisticModel(torch.nn.Module):
def __init__(self):
super(LogisticModel, self).__init__()
self.linear = torch.nn.Linear(1, 1) # 定义线性层,输入维度为1,输出维度为1
def forward(self, x):
y_pred = F.sigmoid(self.linear(x)) # 使用Sigmoid函数将线性层的输出映射到0-1之间,得到预测值
return y_pred
model = LogisticModel() # 实例化逻辑回归模型
criterion = torch.nn.BCELoss(size_average=False) # 定义二元交叉熵损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 定义随机梯度下降优化器,学习率为0.01
for epoch in range(1000): # 进行1000轮训练
y_pred = model(x_data) # 使用模型对输入数据进行预测
loss = criterion(y_pred, y_data) # 计算预测值与真实值之间的损失
print(epoch, loss.item()) # 打印当前轮数和损失值
optimizer.zero_grad() # 清空梯度
loss.backward() # 反向传播计算梯度
optimizer.step() # 更新模型参数
x = np.linspace(0, 10, 200) # 生成0到10之间的200个等间距的点
x_t = torch.Tensor(x).view((200, 1)) # 将numpy数组转换为PyTorch张量,并调整形状为(200, 1)
y_t = model(x_t) # 使用模型对输入数据进行预测
y = y_t.data.numpy() # 将预测结果转换为numpy数组
plt.plot(x, y) # 绘制预测结果曲线
plt.plot([0, 10], [0.5, 0.5], c='r') # 绘制红色水平线,表示阈值为0.5
plt.xlabel('Hour') # 设置x轴标签为'Hour'
plt.ylabel('Pass') # 设置y轴标签为'Pass'
plt.grid() # 显示网格线
plt.show() # 显示图像
以下为上面代码运行结果: