利用梯度下降法进行多变量进行二分类

目录

核心

导入库并且创建数据集

一些参数和超参数得定义

开始训练

完整代码

运行结果


核心

本次二分类使用的算法比较简单,简而言之就是“同号得正,异号得负”,什么意思呢,就是我们的数据集肯定是处理好的,一般二分类问题,它的标签都是+1,-1,用这两个数来表示两个类别。

而我们只需要计算预测值和真实值的乘积,如果他们同号,就视为预测正确,此时,预测值*真实值>0,(不能=0)。

如果他们异号,则预测错误,此时预测值*真实值<=0。

因为我们模型最终训练的对象是参数,我们没必要非得让预测值=真实值,我们在这样做可以节约计算开销,提高模型的效率,我们只需要最后对预测值进行处理即可。

导入库并且创建数据集

import numpy as np
from matplotlib import pyplot as plt, font_manager

print('利用梯度下降法进行多变量进行二分类-------------------------------')
# 数据集中的data部份
x = np.array([[0.180, 0.001 * 1], [0.100, 0.001 * 2],
              [0.160, 0.001 * 3], [0.080, 0.001 * 4],
              [0.090, 0.001 * 5], [0.110, 0.001 * 6],
              [0.120, 0.001 * 7], [0.170, 0.00 * 8],
              [0.150, 0.001 * 9], [0.140, 0.001 * 10],
              [0.130, 0.001 * 11]])
# 数据集中的target部份
y = np.array([+1, -1, +1, -1, -1, -1, -1, +1, +1, +1, -1])

一些参数和超参数得定义

sample, w_fea = x.shape#循环得时候用
print(sample, w_fea)
# 参数
w = np.zeros(w_fea)
# 偏置
bias = 0
# 学习率:
Ir = 0.0001
# 循环次数
a = 0
# 预测值:
pre_y = 0

开始训练

# 开始训练
lista=[]
listw=[]
listb=[]
bath = 101
epoch = 101
for y1 in range(1, bath):
    for y2 in range(1, epoch):
        a = a + 1
        f = False
        for y3 in range(sample):
            pre_y = np.dot(x[y3], w) + bias
            if pre_y * y[y3] <= 0:
                f = True
                w = w + Ir * y[y3] * x[y3]
                bias = bias + Ir * y[y3]
        if not f:
            break
        print('循环轮数:{0},w:{1},bias:{2}'.format(a, w, bias))

这里面得关键就在于if pre_y*y[y3]<=0,这一块就体现了我们在核心实现的思想。

因为们的判断条件是预测值*真实值<=0进行参数更新,我们更新的目的是为了让预测值*真实值>0,相当于让预测值*真实值这个函数增加,也就是说我们要沿梯度的正方向移动,所以,我们就有了

关于为什么要这样,也就是梯度下降相关的看这里深度学习之梯度,梯度下降法以及使用梯度下降法实现线性回归代码实例(完整代码)​​​​​​​

 w = w + Ir * y[y3] * x[y3]
 bias = bias + Ir * y[y3]

完整代码

import torch
import numpy as np
from matplotlib import pyplot as plt, font_manager

print('利用梯度下降法进行多变量进行二分类-------------------------------')
# 数据集中的data部份
x = np.array([[0.180, 0.001 * 1], [0.100, 0.001 * 2],
              [0.160, 0.001 * 3], [0.080, 0.001 * 4],
              [0.090, 0.001 * 5], [0.110, 0.001 * 6],
              [0.120, 0.001 * 7], [0.170, 0.00 * 8],
              [0.150, 0.001 * 9], [0.140, 0.001 * 10],
              [0.130, 0.001 * 11]])
# 数据集中的target部份
y = np.array([+1, -1, +1, -1, -1, -1, -1, +1, +1, +1, -1])

sample, w_fea = x.shape
print(sample, w_fea)
# 参数
w = np.zeros(w_fea)
# 偏置
bias = 0
# 学习率:
Ir = 0.0001
# 循环次数
a = 0
# 预测值:
pre_y = 0
# 开始训练
lista=[]
listw=[]
listb=[]
bath = 101
epoch = 101
for y1 in range(1, bath):
    for y2 in range(1, epoch):
        a = a + 1
        f = False
        for y3 in range(sample):
            pre_y = np.dot(x[y3], w) + bias
            if pre_y * y[y3] <= 0:
                f = True
                w = w + Ir * y[y3] * x[y3]
                bias = bias + Ir * y[y3]
        if not f:
            break
        print('循环轮数:{0},w:{1},bias:{2}'.format(a, w, bias))
    lista.append(a)
    listw.append(w)
    listb.append(bias)
print('w:', w)
print('bias:', bias)
pre_y_ = []
for i in range(sample):
    pre_y = np.dot(x[i], w) + bias
    if pre_y > 0:
        pre_y_.append(1)
    else:
        pre_y_.append(-1)
print('预测值:', pre_y_)
print('真实值:', y)
# lis = []
# for x in range(1, 12):
#     lis.append(x)
# print(lis)
# print(pre_y_)
# font = font_manager.FontProperties(fname="C:\\Users\\ASUS\\Desktop\\Fonts\\STZHONGS.TTF")
# plt.plot(lista,listw,"r",label='w')
# plt.plot(lista,listb,"b",label='b')
#
# plt.title('循环轮数和参数',fontproperties=font, fontsize=18)
# plt.show()
# 绘制决策边界和误分类点
x1 = np.arange(0.135, 0.145, 0.0001)
#print(x1)
x2 = (-w[0] * x1 - bias) / w[1]
plt.plot(x1, x2, 'r', label='Decision Boundary')
for ii in range(11):
    if y[ii] == +1:
        plt.plot([ x[ii][0], x[ii][1], 'bo'])
    else:
        plt.plot([x[ii][0], x[ii][1], 'ro'])
plt.legend()

plt.show()

运行结果

D:\Anaconda3\envs\pytorch\python.exe D:\learn_pytorch\学习过程\第二周的代码\代码二.py 
利用梯度下降法进行多变量进行二分类-------------------------------
11 2
循环轮数:1,w:[ 4.0e-06 -1.6e-06],bias:-0.0001
循环轮数:2,w:[ 2.4e-05 -2.9e-06],bias:-0.0001
循环轮数:3,w:[ 4.4e-05 -4.2e-06],bias:-0.0001
循环轮数:4,w:[ 6.4e-05 -5.5e-06],bias:-0.0001
循环轮数:5,w:[ 8.4e-05 -6.8e-06],bias:-0.0001
循环轮数:6,w:[ 1.04e-04 -8.10e-06],bias:-0.0001
循环轮数:7,w:[ 1.24e-04 -9.40e-06],bias:-0.0001
循环轮数:8,w:[ 1.44e-04 -1.07e-05],bias:-0.0001
循环轮数:9,w:[ 1.64e-04 -1.20e-05],bias:-0.0001
循环轮数:10,w:[ 1.84e-04 -1.33e-05],bias:-0.0001
循环轮数:11,w:[ 2.04e-04 -1.46e-05],bias:-0.0001
循环轮数:12,w:[ 2.24e-04 -1.59e-05],bias:-0.0001
循环轮数:13,w:[ 2.44e-04 -1.72e-05],bias:-0.0001
循环轮数:14,w:[ 2.64e-04 -1.85e-05],bias:-0.0001
循环轮数:15,w:[ 2.84e-04 -1.98e-05],bias:-0.0001
循环轮数:16,w:[ 3.04e-04 -2.11e-05],bias:-0.0001
循环轮数:17,w:[ 3.24e-04 -2.24e-05],bias:-0.0001
循环轮数:18,w:[ 3.44e-04 -2.37e-05],bias:-0.0001
循环轮数:19,w:[ 3.64e-04 -2.50e-05],bias:-0.0001
循环轮数:20,w:[ 3.84e-04 -2.63e-05],bias:-0.0001
循环轮数:21,w:[ 4.04e-04 -2.76e-05],bias:-0.0001
循环轮数:22,w:[ 4.24e-04 -2.89e-05],bias:-0.0001
循环轮数:23,w:[ 4.44e-04 -3.02e-05],bias:-0.0001
循环轮数:24,w:[ 4.64e-04 -3.15e-05],bias:-0.0001
循环轮数:25,w:[ 4.84e-04 -3.28e-05],bias:-0.0001
循环轮数:26,w:[ 5.04e-04 -3.41e-05],bias:-0.0001
循环轮数:27,w:[ 5.24e-04 -3.54e-05],bias:-0.0001
循环轮数:28,w:[ 5.44e-04 -3.67e-05],bias:-0.0001
循环轮数:29,w:[ 5.64e-04 -3.80e-05],bias:-0.0001
循环轮数:30,w:[ 5.76e-04 -3.92e-05],bias:-0.0001
循环轮数:31,w:[ 5.88e-04 -4.04e-05],bias:-0.0001
循环轮数:32,w:[ 5.98e-04 -4.07e-05],bias:-0.0001
循环轮数:33,w:[ 6.08e-04 -4.10e-05],bias:-0.0001
循环轮数:34,w:[ 6.18e-04 -4.13e-05],bias:-0.0001
循环轮数:35,w:[ 6.28e-04 -4.16e-05],bias:-0.0001
循环轮数:36,w:[ 6.30e-04 -4.18e-05],bias:-0.0001
循环轮数:37,w:[ 6.32e-04 -4.20e-05],bias:-0.0001
循环轮数:38,w:[ 6.34e-04 -4.22e-05],bias:-0.0001
循环轮数:39,w:[ 6.36e-04 -4.24e-05],bias:-0.0001
循环轮数:40,w:[ 6.38e-04 -4.26e-05],bias:-0.0001
循环轮数:41,w:[ 6.40e-04 -4.28e-05],bias:-0.0001
循环轮数:42,w:[ 6.42e-04 -4.30e-05],bias:-0.0001
循环轮数:43,w:[ 6.44e-04 -4.32e-05],bias:-0.0001
循环轮数:44,w:[ 6.46e-04 -4.34e-05],bias:-0.0001
循环轮数:45,w:[ 6.48e-04 -4.36e-05],bias:-0.0001
循环轮数:46,w:[ 6.50e-04 -4.38e-05],bias:-0.0001
循环轮数:47,w:[ 6.52e-04 -4.40e-05],bias:-0.0001
循环轮数:48,w:[ 6.54e-04 -4.42e-05],bias:-0.0001
循环轮数:49,w:[ 6.56e-04 -4.44e-05],bias:-0.0001
循环轮数:50,w:[ 6.58e-04 -4.46e-05],bias:-0.0001
循环轮数:51,w:[ 6.60e-04 -4.48e-05],bias:-0.0001
循环轮数:52,w:[ 6.62e-04 -4.50e-05],bias:-0.0001
循环轮数:53,w:[ 6.64e-04 -4.52e-05],bias:-0.0001
循环轮数:54,w:[ 6.66e-04 -4.54e-05],bias:-0.0001
循环轮数:55,w:[ 6.68e-04 -4.56e-05],bias:-0.0001
循环轮数:56,w:[ 6.70e-04 -4.58e-05],bias:-0.0001
循环轮数:57,w:[ 6.71e-04 -4.59e-05],bias:-0.0001
循环轮数:58,w:[ 6.72e-04 -4.60e-05],bias:-0.0001
循环轮数:59,w:[ 6.73e-04 -4.61e-05],bias:-0.0001
循环轮数:60,w:[ 6.74e-04 -4.62e-05],bias:-0.0001
循环轮数:61,w:[ 6.75e-04 -4.63e-05],bias:-0.0001
循环轮数:62,w:[ 6.76e-04 -4.64e-05],bias:-0.0001
循环轮数:63,w:[ 6.77e-04 -4.65e-05],bias:-0.0001
循环轮数:64,w:[ 6.78e-04 -4.66e-05],bias:-0.0001
循环轮数:65,w:[ 6.79e-04 -4.67e-05],bias:-0.0001
循环轮数:66,w:[ 6.80e-04 -4.68e-05],bias:-0.0001
循环轮数:67,w:[ 6.81e-04 -4.69e-05],bias:-0.0001
循环轮数:68,w:[ 6.82e-04 -4.70e-05],bias:-0.0001
循环轮数:69,w:[ 6.83e-04 -4.71e-05],bias:-0.0001
循环轮数:70,w:[ 6.84e-04 -4.72e-05],bias:-0.0001
循环轮数:71,w:[ 6.85e-04 -4.73e-05],bias:-0.0001
循环轮数:72,w:[ 6.86e-04 -4.74e-05],bias:-0.0001
循环轮数:73,w:[ 6.87e-04 -4.75e-05],bias:-0.0001
循环轮数:74,w:[ 6.88e-04 -4.76e-05],bias:-0.0001
循环轮数:75,w:[ 6.89e-04 -4.77e-05],bias:-0.0001
循环轮数:76,w:[ 6.90e-04 -4.78e-05],bias:-0.0001
循环轮数:77,w:[ 6.91e-04 -4.79e-05],bias:-0.0001
循环轮数:78,w:[ 6.92e-04 -4.80e-05],bias:-0.0001
循环轮数:79,w:[ 6.93e-04 -4.81e-05],bias:-0.0001
循环轮数:80,w:[ 6.94e-04 -4.82e-05],bias:-0.0001
循环轮数:81,w:[ 6.95e-04 -4.83e-05],bias:-0.0001
循环轮数:82,w:[ 6.96e-04 -4.84e-05],bias:-0.0001
循环轮数:83,w:[ 6.97e-04 -4.85e-05],bias:-0.0001
循环轮数:84,w:[ 6.98e-04 -4.86e-05],bias:-0.0001
循环轮数:85,w:[ 6.99e-04 -4.87e-05],bias:-0.0001
循环轮数:86,w:[ 7.00e-04 -4.88e-05],bias:-0.0001
循环轮数:87,w:[ 7.01e-04 -4.89e-05],bias:-0.0001
循环轮数:88,w:[ 7.02e-04 -4.90e-05],bias:-0.0001
循环轮数:89,w:[ 7.03e-04 -4.91e-05],bias:-0.0001
循环轮数:90,w:[ 7.04e-04 -4.92e-05],bias:-0.0001
循环轮数:91,w:[ 7.05e-04 -4.93e-05],bias:-0.0001
循环轮数:92,w:[ 7.06e-04 -4.94e-05],bias:-0.0001
循环轮数:93,w:[ 7.07e-04 -4.95e-05],bias:-0.0001
循环轮数:94,w:[ 7.08e-04 -4.96e-05],bias:-0.0001
循环轮数:95,w:[ 7.09e-04 -4.97e-05],bias:-0.0001
循环轮数:96,w:[ 7.10e-04 -4.98e-05],bias:-0.0001
循环轮数:97,w:[ 7.11e-04 -4.99e-05],bias:-0.0001
循环轮数:98,w:[ 7.12e-04 -5.00e-05],bias:-0.0001
循环轮数:99,w:[ 7.13e-04 -5.01e-05],bias:-0.0001
循环轮数:100,w:[ 7.14e-04 -5.02e-05],bias:-0.0001
循环轮数:101,w:[ 7.15e-04 -5.03e-05],bias:-0.0001
循环轮数:102,w:[ 7.16e-04 -5.04e-05],bias:-0.0001
循环轮数:103,w:[ 7.17e-04 -5.05e-05],bias:-0.0001
循环轮数:104,w:[ 7.18e-04 -5.06e-05],bias:-0.0001
w: [ 7.18e-04 -5.06e-05]
bias: -0.0001
预测值: [1, -1, 1, -1, -1, -1, -1, 1, 1, 1, -1]
真实值: [ 1 -1  1 -1 -1 -1 -1  1  1  1 -1]

进程已结束,退出代码0

我的图像是有问题的,绘制决策边界和误分类点,不应该是这样的,大家看运行结果就好

  • 9
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值