逻辑回归小例子

这篇博客介绍了线性回归和逻辑回归两种常见的机器学习模型。线性回归用于预测连续数值,通过最小化均方误差损失函数进行训练。而逻辑回归则用于二分类任务,采用sigmoid激活函数和二分类交叉熵损失函数。文中还展示了在MNIST数据集上使用PyTorch实现逻辑回归的代码示例,并绘制了模型预测概率曲线。
摘要由CSDN通过智能技术生成

线性回归:
仿射模型: y ^ = x ∗ ω + b \hat{y}=x*\omega+b y^=xω+b
损失函数: l o s s = ( y ^ − y ) 2 = ( x ∗ ω − y ) 2 loss=(\hat{y}-y)^2=(x*\omega-y)^2 loss=(y^y)2=(xωy)2
分类任务:MNIST数据集
The database of handwritten digits

  • Training set: 60,000 examples,
  • Test set: 10,000 examples.
  • Classes: 10
import torchvision
train_set = torchvision.datasets.MNIST(root='../dataset/mnist', train=True, download=True)
test_set = torchvision.datasets.MNIST(root='../dataset/mnist', train=False, download=True)

在这里插入图片描述
在分类任务中,模型的输出是输入精确到某一类别的概率。
如何分类?
采用分类函数 σ ( x ) = 1 1 + e ( − x ) \sigma(x)=\frac{1}{1+e^(-x)} σ(x)=1+e(x)1
在这里插入图片描述
逻辑回归:
在这里插入图片描述
二分类任务的损失函数:
l o s s = − ( y l o g y ^ + ( 1 − y ) l o g ( 1 − y ^ ) ) loss=-(ylog\hat{y}+(1-y)log(1-\hat{y})) loss=(ylogy^+(1y)log(1y^))
二分类任务的小批量损失函数:
l o s s = − 1 N ∑ n = 1 N y n l o g y ^ n + ( 1 − y n ) l o g ( 1 − y ^ n ) loss=-\frac{1}{N}\sum_{n=1}^Ny_nlog\hat{y}_n+(1-y_n)log(1-\hat{y}_n) loss=N1n=1Nynlogy^n+(1yn)log(1y^n
在这里插入图片描述

import torch


#数据
x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[0],[0],[1]])

#使用类设计模型
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel,self).__init__()
        self.linear = torch.nn.Linear(1,1)
    
    def forward(self,x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = LogisticRegressionModel()

criterion = torch.nn.BCELoss(size_average=False,reduction='sum')#损失
optimizer = torch.optim.SGD(model.parameters(),lr=0.7)#优化器

for epoch in range(100):#训练
    y_pred = model(x_data)
    loss = criterion(y_pred,y_data)
    print(epoch,loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


结果:
在这里插入图片描述

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0,10,200)
x_t = torch.Tensor(x).view((200,1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x,y)
plt.plot([0,10],[0.5,0.5],c='r')
plt.xlabel('Hours')
plt.ylabel('Probability')
plt.grid()
plt.show()

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

1100dp

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值