学习视频:B站 刘二大人《PyTorch深度学习实践》完结合集
五、逻辑回归(Logistic Regression)
处理分类问题:
计算属于1、2、3…、9的概率为多少
Logistic Regression属于分类问题,输出的是分类的概率值,在训练过程中,计算它属于每一个分类的所有概率,其中概率最大的那一种分类,就是我们要的输出结果。
1. 二分类问题
之前的回归问题是学习时间x与得到的分数y,是预测数字
现在是学习时间x与是否可以通过考试,需要计算y_pred=1的概率
2. Sigmoid function
Sigmoid函数特点:
- 函数值有极限
- 都是增函数
- 饱和函数
Logistic function
Logistic function是最常见的sigmoid function
3. Logistic Regression Model
4. 损失函数
使用二分类交叉熵公式:
交叉熵是对数似然函数的相反数。 对数似然的值我们希望它越大越好,交叉熵的值我们希望它越小越好。
多个样本:
5. 逻辑回归与线性模型的区别
- 多了一个sigmoid函数
- 损失函数不同
其中size_average=True表示loss求平均值,为false表示不求平均
关于loss是否求平均的问题,如果求平均,之后loss的导数也是平均后的,这样loss相对不平均会小,然后loss进行更新的时候会与学习率有关,从而影响学习率的选择。
不求平均,loss的导数值大,loss更新与学习率相关。
很多时候loss不求平均,会在求导数的时候进行平均。
6. 代码实现
import torch.nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
#准备数据集
x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[0],[0],[1]])
#定义模型
class LogisticRegressionModel(torch.nn.Module):
def __init__(self):
super(LogisticRegressionModel,self).__init__() #调用父类的构造
self.linear = torch.nn.Linear(1,1) #torch.nn.Linear是torch中的一个类,torch.nn.Linear(1,1)构造了一个对象,包括权重和偏置
def forward(self,x):
y_pred =F.sigmoid(self.linear(x))
return y_pred
model = LogisticRegressionModel()
#构造损失函数和优化器的选择
criterion = torch.nn.BCELoss(size_average=False) #需要传入的参数为y_pred,y_data
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
#训练的过程
epoch_list =[]
loss_list=[]
for epoch in range(1000):
y_pred = model(x_data)
loss = criterion(y_pred,y_data)
print(epoch,loss.item())
optimizer.zero_grad() #梯度归零
loss.backward()
optimizer.step() #梯度更新
epoch_list.append(epoch + 1)
loss_list.append(loss.item())
# 画图
plt.plot(epoch_list,loss_list)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()
运行结果:
测试模型:
处在通过与不通过之前的值为2.5
代码:
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0,10,200) #0-10之间有200个点
x_t = torch.Tensor(x).view((200,1)) #将200个点变成(200,1)的张量
y_t = model(x_t) #得到y_t是个张量
y = y_t.data.numpy()#将y_t张量转化为矩阵形式
plt.plot(x,y)
plt.plot([0,10],[0.5,0.5],c='r') #50%的概率
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()
运行结果:
参考资料
https://blog.csdn.net/lizhuangabby/article/details/125610051
https://blog.csdn.net/qq_43800119/article/details/126415539?spm=1001.2014.3001.5502