BFGS 拟牛顿法 二分类优化问题

# -*- coding: utf-8 -*-
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
#from mpl_toolkits.mplot3d import Axes3D
def loadDataSet():
    dataMat = []; labelMat = []
    try:
        fr = open('H:\Downloads\BFGS-master\in3D.txt')
    except IOError:
        print("请检查您的路径")
    else:
        for line in fr.readlines():
            lineArr = line.strip().split(',')
            dataMat.append([1.0, float(lineArr[0]), float(lineArr[1]),float(lineArr[2])])#b0,x1,x2,x3
            labelMat.append(int(lineArr[3]))
            #dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])])#b0,x1,x2
            #labelMat.append(int(lineArr[2]))
            
        fr.close()
        return np.mat(dataMat),np.mat(labelMat)

#归一化函数 0-1
def sigmoid(inX):
    return 1.0/(1+np.exp(-inX))



#Pick initial point x_0
def D2plotBestFit(weights):
    dataMat,labelMat=loadDataSet()
    labelMat=labelMat.tolist()[0]
    dataArr = np.array(dataMat)
    n = np.shape(dataArr)[0] 
    xcord1 = []; ycord1 = []
    xcord2 = []; ycord2 = []
    for i in range(n):
        if int(labelMat[i])== 1:
            xcord1.append(dataArr[i,1]); ycord1.append(dataArr[i,2])
        else:
            xcord2.append(dataArr[i,1]); ycord2.append(dataArr[i,2])
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(xcord1, ycord1, s=30, c='red', marker='s')
    ax.scatter(xcord2, ycord2, s=30, c='green')
    x = np.arange(20, 120, 10)
    y = (-weights[0]-weights[1]*x)/weights[2]
    y=y.tolist()
    ax.plot(x, y[0])
    plt.xlabel('X1'); plt.ylabel('X2');
    plt.show()
    
def D3plotBestFit(weights):
    dataMat,labelMat=loadDataSet()
    labelMat=labelMat.tolist()[0]
    dataArr = np.array(dataMat)
    n = np.shape(dataArr)[0] 
    xcord1 = []; ycord1 = [];zcord1 = []
    xcord2 = []; ycord2 = [];zcord2 = []
    for i in range(n):
        if int(labelMat[i])== 1:
            xcord1.append(dataArr[i,1]); ycord1.append(dataArr[i,2]); zcord1.append(dataArr[i,3])
        else:
            xcord2.append(dataArr[i,1]); ycord2.append(dataArr[i,2]); zcord2.append(dataArr[i,3])

    ax = plt.axes(projection='3d')
    ax.scatter3D(xcord1, ycord1,zcord1, s=30, c='red', marker='s')
    ax.scatter3D(xcord2, ycord2,zcord2, s=30, c='green')
    
    x = np.arange(-1, 4, 0.5)
    y = np.arange(-1, 4, 0.5)
    print(np.shape(weights),np.shape(x),np.shape(y));
    z = (-weights[0]-weights[1]*x-weights[2]*y)/weights[3]
    ax.plot_wireframe(x,y,z,color='r')
    
    ax.set_xlabel('X1 Label')
    ax.set_ylabel('X2 Label')
    ax.set_zlabel('X3 Label')
    plt.show()

def MinFlamda(x,y,xk,pk): #选取前2000次迭代cost最小的alpha
    c=float("inf")
    t=xk
    for k in range(1,2000):
            a=1.0/k**2
            xk = t + a * pk
            f= np.sum(np.dot(x.T,sigmoid(np.dot(x,xk))-y))
            if abs(f)>c:
                break
            c=abs(f)
            alpha=a
    return alpha

def BFGS(x,y, iter,error):#BFGS拟牛顿法
    n = np.shape(x)[1]#输入参数数量
    xk=np.full((n,1),2)#n行1列,wi值为2,
    y=np.mat(y).T
    Bk=np.eye(n,n)#n阶单位阵,保证是正定的
    gk = np.dot(x.T,sigmoid(np.dot(x,xk))-y)
    cost=[]
    for it in range(iter):
        pk = -1 * np.linalg.solve(Bk, gk)   #搜索方向
        rate=MinFlamda(x,y,xk,pk)
        xk = xk + rate * pk
        grad= np.dot(x.T,sigmoid(np.dot(x,xk))-y)
        delta_xk = rate * pk
        delta_yk = grad - gk
        Pk = delta_yk.dot(delta_yk.T) / (delta_yk.T.dot(delta_xk))
        Qk= Bk.dot(delta_xk).dot(delta_xk.T).dot(Bk) / (delta_xk.T.dot(Bk).dot(delta_xk)) * (-1)
        Bk += Pk + Qk
        gk = grad
        print(np.sum(gk))
        if abs(np.sum(gk))<error:
            break
        cost.append(np.sum(gk))
    return xk,cost



x,y=loadDataSet()
wi,y1=BFGS(x,y, iter=40,error=0.001)
D3plotBestFit(wi)


3,3,3,1
4,3,2,1
2,1,2,1
1,1,1,0
-1,0,1,0
2,-1,1,0




34.62365962451697,78.0246928153624,0
30.28671076822607,43.89499752400101,0
35.84740876993872,72.90219802708364,0
60.18259938620976,86.30855209546826,1
79.0327360507101,75.3443764369103,1
45.08327747668339,56.3163717815305,0
61.10666453684766,96.51142588489624,1
75.02474556738889,46.55401354116538,1
76.09878670226257,87.42056971926803,1
84.43281996120035,43.53339331072109,1
95.86155507093572,38.22527805795094,0
75.01365838958247,30.60326323428011,0
82.30705337399482,76.48196330235604,1
69.36458875970939,97.71869196188608,1
39.53833914367223,76.03681085115882,0
53.9710521485623,89.20735013750205,1
69.07014406283025,52.74046973016765,1
67.94685547711617,46.67857410673128,0
70.66150955499435,92.92713789364831,1
76.97878372747498,47.57596364975532,1
67.37202754570876,42.83843832029179,0
89.67677575072079,65.79936592745237,1
50.534788289883,48.85581152764205,0
34.21206097786789,44.20952859866288,0
77.9240914545704,68.9723599933059,1
62.27101367004632,69.95445795447587,1
80.1901807509566,44.82162893218353,1
93.114388797442,38.80067033713209,0
61.83020602312595,50.25610789244621,0
38.78580379679423,64.99568095539578,0
61.379289447425,72.80788731317097,1
85.40451939411645,57.05198397627122,1
52.10797973193984,63.12762376881715,0
52.04540476831827,69.43286012045222,1
40.23689373545111,71.16774802184875,0
54.63510555424817,52.21388588061123,0
33.91550010906887,98.86943574220611,0
64.17698887494485,80.90806058670817,1
74.78925295941542,41.57341522824434,0
34.1836400264419,75.2377203360134,0
83.90239366249155,56.30804621605327,1
51.54772026906181,46.85629026349976,0
94.44336776917852,65.56892160559052,1
82.36875375713919,40.61825515970618,0
51.04775177128865,45.82270145776001,0
62.22267576120188,52.06099194836679,0
77.19303492601364,70.45820000180959,1
97.77159928000232,86.7278223300282,1
62.07306379667647,96.76882412413983,1
91.56497449807442,88.69629254546599,1
79.94481794066932,74.16311935043758,1
99.2725269292572,60.99903099844988,1
90.54671411399852,43.39060180650027,1
34.52451385320009,60.39634245837173,0
50.2864961189907,49.80453881323059,0
49.58667721632031,59.80895099453265,0
97.64563396007767,68.86157272420604,1
32.57720016809309,95.59854761387875,0
74.24869136721598,69.82457122657193,1
71.79646205863379,78.45356224515052,1
75.3956114656803,85.75993667331619,1
35.28611281526193,47.02051394723416,0
56.25381749711624,39.26147251058019,0
30.05882244669796,49.59297386723685,0
44.66826172480893,66.45008614558913,0
66.56089447242954,41.09209807936973,0
40.45755098375164,97.53518548909936,1
49.07256321908844,51.88321182073966,0
80.27957401466998,92.11606081344084,1
66.74671856944039,60.99139402740988,1
32.72283304060323,43.30717306430063,0
64.0393204150601,78.03168802018232,1
72.34649422579923,96.22759296761404,1
60.45788573918959,73.09499809758037,1
58.84095621726802,75.85844831279042,1
99.82785779692128,72.36925193383885,1
47.26426910848174,88.47586499559782,1
50.45815980285988,75.80985952982456,1
60.45555629271532,42.50840943572217,0
82.22666157785568,42.71987853716458,0
88.9138964166533,69.80378889835472,1
94.83450672430196,45.69430680250754,1
67.31925746917527,66.58935317747915,1
57.23870631569862,59.51428198012956,1
80.36675600171273,90.96014789746954,1
68.46852178591112,85.59430710452014,1
42.0754545384731,78.84478600148043,0
75.47770200533905,90.42453899753964,1
78.63542434898018,96.64742716885644,1
52.34800398794107,60.76950525602592,0
94.09433112516793,77.15910509073893,1
90.44855097096364,87.50879176484702,1
55.48216114069585,35.57070347228866,0
74.49269241843041,84.84513684930135,1
89.84580670720979,45.35828361091658,1
83.48916274498238,48.38028579728175,1
42.2617008099817,87.10385094025457,1
99.31500880510394,68.77540947206617,1
55.34001756003703,64.9319380069486,1
74.77589300092767,89.52981289513276,1

我放弃了,c++写不出来,就用py了

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值