二分类logistic回归——平面点分类问题的实现

二分类logistic回归——平面点分类问题的实现

概述

对于二分类问题,logistic回归的目标是希望找到一个决策边界,将两类区分开来

感知机模型

对于一个输入 x x x,如果存在样本点使得 h w ( x ) = ∑ i = 1 m w i x i + b > 0 h_w(x)=\sum_{i=1}^mw_ix_i+b>0 hw(x)=i=1mwixi+b>0,那么判定它的类别为1,否则判定它的类别为0

logistic回归

在感知机模型基础上进行了改进,通过分类概率 P ( Y = 1 ) P(Y=1) P(Y=1)与输入 x x x之间的关系判别类型。假设一个事件发生的概率为 P P P,不发生的概率为 1 − P 1-P 1P,那么定义该事件发生的几率为 P 1 − P \frac{P}{1-P} 1PP,定义 l o g i t logit logit函数为:
l o g i t ( p ) = l o g ( p 1 − p ) = w ∗ x + b logit(p)=log(\frac{p}{1-p})=w*x+b logit(p)=log(1pp)=wx+b
w ∗ x + b w*x+b wx+b的值越接近 + ∞ +\infty +,几率越接近1,当 w ∗ x + b w*x+b wx+b越接近 − ∞ -\infty ,几率越接近0,用这个函数来决定目标属于哪一类

于是,对于训练集数据 T = ( x 1 , y 1 ) , ( x 2 , y 2 ) , . . . T={(x_1,y_1),(x_2,y_2),...} T=(x1,y1),(x2,y2),...,假设 P ( Y = 1 ∣ x ) = p P(Y=1|x)=p P(Y=1x)=p,那么 P ( Y = 0 ∣ x ) = 1 − p P(Y=0|x)=1-p P(Y=0x)=1p,所以似然函数为:
∏ i = 1 n p y i [ 1 − p ] 1 − y i \prod_{i=1}^np^{y_i}[1-p]^{1-y_i} i=1npyi[1p]1yi
取对数之后得到

L ( w ) = ∑ i = 1 n [ y i l o g ( p ) + ( 1 − y i ) l o g ( 1 − p ) ] = ∑ i = 1 n [ y i l o g ( p 1 − p ) + l o g ( 1 − p ) ) ] = ∑ i = 1 n [ y i ( w ∗ x + b ) − l o g ( 1 + e w ∗ x + b ) ] L(w)=\sum_{i=1}^n[y_ilog(p)+(1-y_i)log(1-p)]\\=\sum_{i=1}^n[y_ilog(\frac{p}{1-p})+log(1-p))]\\=\sum_{i=1}^n[y_i(w*x+b)-log(1+e^{w*x+b})] L(w)=i=1n[yilog(p)+(1yi)log(1p)]=i=1n[yilog(1pp)+log(1p))]=i=1n[yi(wx+b)log(1+ewx+b)]

之后只需求 ∂ L ( w ) ∂ w \frac{\partial L(w)}{\partial w} wL(w) ∂ L ( w ) ∂ b \frac{\partial L(w)}{\partial b} bL(w)即可反向传播,得到一个网络,输入 x x x,输出 p p p

平面点分类问题

平面上有一些点,其中部分属于集合1,部分属于集合0,保证这些点的划分是线性的,现在需要找到一条直线,将属于不同集合的点完美划分在直线两边

分析

使用logistic回归,输入数值对 ( x i , y i ) (x_i,y_i) (xi,yi),输出属于集合1的概率p

如果 p > 0.5 p>0.5 p>0.5则判定该点属于集合1,否则判定该点属于集合0

代码实现

编译环境:Pytorch3.7

语言:python


首先创建data.txt文件记录各个点的数据:
6Vb3dO.png
其中第一列是x坐标,第二列是y坐标,第三列表示所属的集合是1还是0

之后运行如下程序:

import torch
from torch import nn
from matplotlib import pyplot as plt
from torch.autograd import Variable
import numpy as np

# 读取数据
with open('data.txt', 'r') as f:
    data_list = f.readlines()
    data_list = [i.split('\n')[0] for i in data_list]
    data_list = [i.split(',') for i in data_list]
    data = [(float(i[0]), float(i[1]), float(i[2])) for i in data_list]
x_data = [(float(i[0]), float(i[1])) for i in data]
y_data = [(float(i[2])) for i in data]
x_data = torch.Tensor(x_data)
y_data = torch.Tensor(y_data)
y_data = y_data.view(10, 1)
x0 = list(filter(lambda x: x[-1] == 0.0, data))
x1 = list(filter(lambda x: x[-1] == 1.0, data))
plot_x0_0 = [i[0] for i in x0]
plot_x0_1 = [i[1] for i in x0]
plot_x1_0 = [i[0] for i in x1]
plot_x1_1 = [i[1] for i in x1]
plt.plot(plot_x0_0, plot_x0_1, 'ro', label='x_0')
plt.plot(plot_x1_0, plot_x1_1, 'bo', label='x_1')
plt.legend(loc='best')


# 定义logistic模型
class LogisticRegerssion(nn.Module):
    def __init__(self):
        super(LogisticRegerssion, self).__init__()
        self.mode = nn.Linear(2, 1)
        self.sm = nn.Sigmoid()

    def forward(self, x):
        x = self.mode(x)
        x = self.sm(x)
        return x

#构造网络、优化器、损失函数
logistic_model = LogisticRegerssion()
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(logistic_model.parameters(), lr=1e-3, momentum=0.9)

#开始训练
for epoch in range(50000):
    x = Variable(x_data)
    y = Variable(y_data)
    # forward
    out = logistic_model(x)
    loss = criterion(out, y)
    print_loss = loss.data
    mask = out.ge(0.5).float()
    correct = (mask == y).sum()
    acc = correct.data / x.size(0)
    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 1000 == 0:
        print('epoch {} loss {} acc {}'.format(epoch + 1, print_loss, acc))

#取出参数,绘制直线
w0, w1 = logistic_model.mode.weight[0]
w0 = w0.data
w1 = w1.data
b = logistic_model.mode.bias.data[0]
plot_x = np.arange(0, 4, 0.1)
plot_y = (-w0 * plot_x - b) / w1
plt.plot(plot_x, plot_y)
plt.show()

得到结果:
6Vb1eK.png
可以看到找到了一条直线划分了属于不同集合的点。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值