pytorch之逻辑回归

这篇博客通过PyTorch构建了一个逻辑回归模型,用于二分类问题。作者生成了模拟数据,然后选择了损失函数(BCELoss)和优化器(SGD),进行了模型训练,并在每个迭代周期中绘制了决策边界。当模型准确率超过99%时,训练结束。
摘要由CSDN通过智能技术生成
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(10)

#生成数据
sample_nums = 100
mean_value = 1.7
bias = 1
n_data = torch.ones(sample_nums,2)
x0 = torch.normal(mean_value * n_data,1) + bias
y0 = torch.zeros(sample_nums)
x1 = torch.normal(-mean_value * n_data,1) + bias
y1 = torch.ones(sample_nums)
train_x = torch.cat((x0,x1),0)
train_y = torch.cat((y0,y1),0)

#选择模型
class LR(nn.Module):
    def __init__(self):
        super(LR,self).__init__()
        self.features = nn.Linear(2,1)
        self.sigmoid = nn.Sigmoid()

    def forward(self,x):
        x = self.features(x)
        x = self.sigmoid(x)
        return x

lr_net = LR() #实例化逻辑回归模型

#选择损失函数
loss_fn = nn.BCELoss()

#选择优化器
lr = 0.01
optimizer = torch.optim.SGD(lr_net.parameters(),lr=lr,momentum=0.9)

#模型训练
for iteration in range(1000):

    #前向传播
    y_pred = lr_net(train_x)

    #计算loss
    loss = loss_fn(y_pred.squeeze(),train_y)

    #反向传播
    loss.backward()

    #更新参数
    optimizer.step()

    #绘图
    if iteration % 20 ==0:

        mask = y_pred.ge(0.5).float().squeeze() #以0.5为阈值进行分类
        correct = (mask == train_y).sum() #计算正确预测的样本数量
        acc = correct.item() / train_y.size(0) #计算分类预测率

        plt .scatter(x0.data.numpy()[:,0],x0.data.numpy()[:,1],c='r',label='class 0')
        plt .scatter(x1.data.numpy()[:,0],x1.data.numpy()[:,1],c='b',label='class 1')

        w0,w1 = lr_net.features.weight[0]
        w0,w1 = float(w0.item()),float(w1.item())
        plot_b = float(lr_net.features.bias[0].item())
        plot_x = np.arange(-6,6,0.1)
        plot_y = (-w0 * plot_x - plot_b) / w1

        plt.xlim(-5,7)
        plt.ylim(-7,7)
        plt.plot(plot_x,plot_y)

        plt.text(-5,5,'Loss=%.4f' % loss.data.numpy(),fontdict={'size':20,'color':'red'})
        plt.title("Iteration:{}\nw0:{:.2f} w1:{:.2f} b:{:.2f} accuracy:{:.2%}".format(iteration,w0,w1,plot_b,acc))
        plt.legend()

        plt.show()
        plt.pause(0.5)

        if acc > 0.99:
            break

在这里插入图片描述
在这里插入图片描述

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值