The Logistic Regression Algorithm

Introduction to Supervised Learning

In supervised learning, classification and regression algorithms are the two most important families of algorithms. The main difference between them is that in classification the labels are discrete values, e.g. in the ad click-through problem the labels are {+1, -1}, representing a click and a non-click respectively, whereas in regression the labels are continuous values, e.g. predicting a person's age from information such as height, sex, and weight; since age is a positive integer in a bounded range, the label satisfies $y\in \mathbb{N}^{+},\ y\in [1,80]$.

Preface

A classification algorithm assigns samples to specified categories according to their features. Classification is a supervised learning task: from the training samples we learn a mapping from sample features to sample labels, also called the hypothesis function, and this mapping is then used to obtain the labels of new samples, so that new samples can be assigned to the different categories.
Logistic Regression is a widely used classification algorithm. From the positive and negative samples in the training data, it learns the hypothesis function mapping sample features to sample labels. Logistic Regression is a typical linear classifier; thanks to its low complexity and ease of implementation it is widely used in industry, for example for click-through-rate prediction of ads.

The Logistic Regression Model

Linearly Separable vs. Linearly Non-separable

A classification problem can usually be divided into two cases: linearly separable and linearly non-separable. If the problem can be classified correctly by a linear discriminant function, it is said to be linearly separable, as shown in Figure 1.1; otherwise it is a linearly non-separable problem, as shown in Figure 1.2.
[Figure 1.1: a linearly separable data set]  [Figure 1.2: a linearly non-separable data set]

The Logistic Regression Model

The Logistic Regression model is a kind of generalized linear model and belongs to the linear classification models. For a linearly separable problem, we need to find a line that separates the two classes; this line is also called a hyperplane (in higher dimensions). The linear function of the hyperplane is written as:
$Wx+b=0$
where W is the weight and b the bias; when the input has multiple dimensions, W is a weight vector. The hyperplane divides the data into a positive class and a negative class; a threshold function can then be used to map samples to the different classes. A common choice of threshold function is the Sigmoid function, which has the form:
$f(x)=\frac{1}{1+e^{-x}}$
A Python implementation of the Sigmoid function:

import numpy as np

def sig(x):
	'''Sigmoid function
	input:  x(mat): feature * w
	output: sigmoid(x)(mat): the Sigmoid value
	'''
	return 1.0 / (1 + np.exp(-x))

For an input vector X, the probability that it belongs to the positive class is:
$P(y=1|X,W,b)=\sigma(WX+b)=\frac{1}{1+e^{-(WX+b)}}$
where $\sigma$ denotes the Sigmoid function. Accordingly, for the input vector X, the probability that it belongs to the negative class is:
$P(y=0|X,W,b)=1-P(y=1|X,W,b)=1-\sigma(WX+b)=\frac{e^{-(WX+b)}}{1+e^{-(WX+b)}}$
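
As a concrete illustration, the snippet below evaluates these two probabilities for a single sample. It is a minimal sketch: the weights, bias, and sample values are made-up numbers, and sig is the Sigmoid function shown above, repeated so the snippet runs on its own.

import numpy as np

def sig(x):
    return 1.0 / (1 + np.exp(-x))

W = np.array([0.8, -1.2])  # hypothetical weight vector
b = 0.5                    # hypothetical bias
X = np.array([2.0, 1.0])   # hypothetical input sample

p_pos = sig(np.dot(W, X) + b)  # P(y=1|X,W,b), roughly 0.71 for these numbers
p_neg = 1.0 - p_pos            # P(y=0|X,W,b)
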
For the Logistic Regression algorithm, what has to be solved for are the parameters of the hyperplane, namely W and b. How can they be solved for? To do that, we first have to define a loss function.

Loss Function

To find W and b in the problem above, we can estimate them with maximum likelihood. Suppose the training set contains m training samples $\{(X^{(1)},y^{(1)}),(X^{(2)},y^{(2)}),\ldots,(X^{(m)},y^{(m)})\}$; its likelihood function is:
$L_{W,b}=\prod\limits_{i=1}^{m}\left[h_{W,b}(X^{(i)})^{y^{(i)}}\left(1-h_{W,b}(X^{(i)})\right)^{1-y^{(i)}}\right]$
where the hypothesis function $h_{W,b}(X^{(i)})$ is:
$h_{W,b}(X^{(i)})=\sigma(WX^{(i)}+b)$
To maximize the likelihood, one usually works with the log-likelihood. In the Logistic Regression algorithm, the negative log-likelihood is taken as the loss function, so what has to be computed is the minimum of the NLL (the negative log-likelihood). The loss function $l_{W,b}$ is:
$l_{W,b}=-\frac{1}{m}\log L_{W,b}=-\frac{1}{m}\sum\limits_{i=1}^{m}\left[y^{(i)}\log\left(h_{W,b}(X^{(i)})\right)+\left(1-y^{(i)}\right)\log\left(1-h_{W,b}(X^{(i)})\right)\right]$
The problem we need to solve is therefore: $\min\limits_{W,b}\ l_{W,b}$
To find the minimum of the loss function $l_{W,b}$, gradient-based methods can be used.
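For reference, differentiating the loss above with respect to W (treating the bias b as an extra weight on a constant feature 1, as the implementation later in this post does) gives the batch gradient
$\frac{\partial l_{W,b}}{\partial W}=-\frac{1}{m}\sum\limits_{i=1}^{m}\left(y^{(i)}-h_{W,b}(X^{(i)})\right)X^{(i)}$
so one gradient-descent step is
$W\leftarrow W+\frac{\alpha}{m}\sum\limits_{i=1}^{m}\left(y^{(i)}-h_{W,b}(X^{(i)})\right)X^{(i)}$
which is exactly the update used by lr_train_bgd below, with the constant factor 1/m folded into the learning rate $\alpha$.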

Gradient Descent

Overview of Gradient Descent

Gradient descent works by using the gradient at the current point to find a new iterate, moving from the current point to that new iterate, and then repeating the search until the optimum is found.
The detailed procedure of gradient descent is:
Step 1: pick a random initial point $w_0$.
Step 2: repeat the following process:
1. Determine the descent direction: $d_i=-\frac{\partial}{\partial w}f(w)\Big|_{w_i}$
2. Choose a step size $\alpha$.
3. Update: $w_{i+1}=w_i+\alpha\cdot d_i$
Step 3: stop once the termination condition is satisfied.
The procedure is illustrated in Figure 1.4:
[Figure 1.4: the iterative process of gradient descent]
Starting at the point $w_0$, we choose the descent direction $d_0$ and a step size $\alpha$, update the value of w and arrive at $w_1$; we then check the termination condition, find that the optimum $w^{*}$ has not yet been reached, and repeat the process until $w^{*}$ is reached.
In gradient descent, the choice of step size and of the descent direction matters and deserves particular attention; a small generic sketch of this loop is given below.
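
A minimal sketch of this loop on a simple one-dimensional function follows; the objective f(w) = (w - 3)^2 and all constants are chosen only for illustration and are not part of the LR model:

def grad_descent(df, w0, alpha=0.1, tol=1e-6, max_iter=1000):
    '''Generic one-dimensional gradient descent.
    input:  df: derivative of the objective
            w0: initial point
            alpha: step size
            tol: stop once the update is smaller than this
            max_iter: maximum number of iterations
    output: the final iterate w
    '''
    w = w0
    for _ in range(max_iter):
        d = -df(w)                 # descent direction d_i = -f'(w_i)
        w_next = w + alpha * d     # update w_{i+1} = w_i + alpha * d_i
        if abs(w_next - w) < tol:  # termination condition
            return w_next
        w = w_next
    return w

# f(w) = (w - 3)^2 has derivative 2 * (w - 3); the minimum is at w = 3
print(grad_descent(lambda w: 2 * (w - 3), w0=0.0))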

Convex vs. Non-convex Optimization

Simply put, a convex optimization problem is one with a single optimum: any local optimum is also the global optimum. Intuitively, the graph of the objective has a single valley, and the bottom of that valley is the optimum.
By contrast, in a non-convex problem there are several local optima in the solution space, and the global optimum is only one of those local optima.
The loss functions of ordinary least squares, ridge regression, and Logistic Regression all lead to convex optimization problems.
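As a one-dimensional illustration (these two functions are chosen only as examples and do not appear in the original text): $f(w)=(w-1)^2$ is convex and has the single minimum $w=1$, whereas $g(w)=w^4-2w^2$ is non-convex, with two separate local minima at $w=\pm 1$ and a local maximum at $w=0$.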

Implementing Logistic Regression in Python

The code below was tested under Python 2.7.10.
First, prepare the training set data.txt; each line contains two features and a label, separated by tabs:

4.45925637575900	8.22541838354701	0
0.0432761720122110	6.30740040001402	0
6.99716180262699	9.31339338579386	0
4.75483224215432	9.26037784240288	0
8.66190392439652	9.76797698918454	0
7.17376551727269	8.69456339325210	0
0.134053879775005	1.96878052943837	0
2.95850718791494	5.80458392655308	0
0.162199197495798	2.59575596457315	0
3.99590517062991	8.83289511075255	0
6.13130341636350	9.18109248691241	0
4.13802790208159	9.60748524925563	0
8.90049077697510	9.99329318936273	0
4.15371012638716	9.53922527310308	0
3.47806488316686	7.80788315923614	0
5.04287107839320	9.38996803836429	0
2.56324459286653	7.83286351946551	0
5.42032358874179	8.77024851395462	0
2.78940677946772	5.84911155908821	0
1.69377595965126	3.42939148982086	0
2.57907558635543	5.85738793565177	0
4.82185528268354	9.96169885562949	0
0.253602746238408	7.45945161909674	0
8.14641315834860	9.88547372518201	0
1.07252390677591	9.64544522650273	0
3.96810711848012	9.37483872272884	0
7.78456478657554	9.21622730177576	0
5.68031802716968	9.55034658708916	0
0.188925610590359	1.86579459185882	0
3.39042984564948	5.22734216789434	0
0.306872061364665	8.10706884864857	0
2.94520568616781	8.89485671892082	0
1.56257145202601	7.59114877849288	0
7.88040432728020	9.72811231772191	0
2.06663143127242	5.55531983916915	0
3.26847684673349	5.80133323123041	0
3.03572486052262	4.55438209687869	0
4.06403686804154	7.30824174639464	0
0.261319587511897	6.82874924285753	0
0.535129082770840	2.96735234780109	0
6.16211768166533	8.73591445232541	0
0.0543225029100144	3.63143937000534	0
3.38950408283417	5.21140264842544	0
0.670802539239751	5.50842093397427	0
3.33354941448896	9.00586384747725	0
4.83782090456725	9.24046925519147	0
4.14342743933778	5.21158367718785	0
0.0448488462892471	2.53343482171120	0
3.29344955007619	8.39195386760030	0
1.43481273099994	4.54377111225283	0
5.80444602917862	7.72222238849776	0
2.89737803415351	4.84582798104369	0
3.48896826617082	9.42538199279707	0
7.98990181138566	9.38748992074748	0
6.07911967934092	7.81580715692381	0
8.54988937636567	9.83106545896296	0
1.86253147316586	3.64519173433558	0
5.09264649345917	7.16456405109382	0
0.640487340762152	2.96504627163533	0
0.445682665759531	7.27017831379406	0
7.03580231161425	9.62162716377990	0
2.38548233873891	9.31142472376713	0
5.47187479221741	6.52268275403238	0
3.09625873884070	7.24687725634908	0
5.64986319346389	8.14649426712568	0
7.29995079462111	8.54616550280286	0
8.77346437600257	9.96234073280813	0
3.06091044795100	7.72080944560689	0
2.73380551789568	4.29286766635781	0
4.69373224420759	9.24765856154850	0
6.87533960169376	9.88875131106959	0
3.00192545879843	6.33960532829528	0
1.35371591887987	4.90531471836317	0
4.98310945603791	8.19357999963722	0
1.44378940451651	3.32854209513079	0
3.58695390614628	9.08726843217394	0
1.66501084835760	6.33826894701403	0
1.13683031909428	8.93555055642300	0
6.89981119465722	9.08506017228389	0
3.50612434749800	7.19887972236541	0
7.98213686241980	9.90335214596799	0
4.48550059813567	7.87469534845766	0
8.18730647207947	9.65752968521276	0
6.98206593291099	9.31671668967771	0
4.22691515046792	6.27588464716913	0
5.42330764790338	7.08234539249773	0
1.77758743959901	9.00325034532830	0
1.13084327057428	7.21101548428525	0
3.93586389272097	8.03817933454068	0
6.63734173280782	8.35388742591753	0
0.386709419537340	8.58317856311490	0
3.32377361598771	8.07768523594907	0
8.06147926210387	9.31872691969549	0
8.98108629344805	9.99686700846425	0
7.12209318712481	9.35515066520718	0
4.89305717339869	7.48113528632786	0
7.40008878777077	9.35257115162952	0
7.03477188625162	9.97577769141571	0
7.09428420398105	9.71970901043399	0
4.37051546270230	9.41555568974276	0
4.00462374868525	0.616899669845488	1
5.44921361045783	0.138855328072746	1
8.44816973291296	1.96823424949569	1
7.09749399121559	4.84058301823207	1
7.13659478263630	3.99344574129641	1
3.13607555660068	1.01976990251471	1
9.42777338698193	2.03196575362135	1
2.88202491477532	0.512799144639290	1
7.98235678340026	2.31403273981939	1
6.42767231292936	0.998772381680681	1
1.78712965201534	0.243156456792515	1
3.07790563815491	1.88920490138202	1
9.43195911269671	0.268801081716839	1
6.34282338934720	0.234109280621259	1
4.82412559000009	1.99459231605986	1
8.56305479341248	4.72713371076877	1
3.29684212405129	2.07793293965495	1
7.90570542236263	3.88560220548909	1
9.07438180273557	2.97201368564586	1
4.03957906251732	1.88105978289717	1
1.78205895950614	0.293661208884570	1
2.68642960682293	1.36689620711673	1
7.50923375595256	4.39161107527268	1
9.22583275513275	6.75308704026472	1
5.93852172849132	0.0767525003841111	1
1.85217721700788	0.438665288257335	1
9.57234374408993	7.22701037061452	1
9.41511021370755	8.19137839466097	1
4.35273751059844	1.07589578304157	1
8.04923110762197	1.49599991038707	1
9.63236985254770	8.03011417611781	1
4.29156538673758	1.13530168347901	1
7.23413465556835	4.46740498567466	1
8.20309146793921	4.27626275005668	1
1.87457510864822	0.489332789968566	1
1.38780964339963	0.293816887566648	1
6.93682776203853	3.55888673905837	1
3.04295274833797	1.45808471523332	1
7.52052171782518	2.24448074758751	1
1.29475743274558	0.248662041648605	1
7.93270103161629	3.53322776478622	1
7.36266950886594	5.03397809680607	1
3.97602116402406	2.71911655949805	1
6.19034466536054	1.49434836516839	1
3.40178374234313	2.13801547822256	1
4.29491070615890	0.494287306651284	1
9.51694812792873	5.54313061426752	1
9.83993111985939	1.71347650375439	1
8.45868335411795	3.68311342408982	1
6.31700682494856	0.238683478937646	1
6.08012336783280	3.21848321902105	1
5.48289786541560	0.267286639106428	1
5.63646865791522	0.246636667516577	1
8.59268728497635	4.51516446747963	1
7.21916947499839	4.02521298309224	1
6.61095852079607	5.35205946687567	1
2.59249592920151	0.756284591960114	1
8.58163248547907	3.76779427699367	1
7.86164254606212	1.23425344739733	1
7.22459851512079	5.69192759004084	1
2.46391682231140	0.626697111247094	1
1.37666679614458	0.0850311192606557	1
6.99535329428774	5.22402041296117	1
4.10255509353404	0.0845494795236941	1
9.14754400713588	3.83800038341603	1
9.19417616894172	2.68752405465039	1
7.21993188858784	0.108525252131188	1
5.61112164928016	1.53003323224417	1
4.02501202704844	2.88388192100925	1
7.75454869096328	2.25550746659844	1
6.69549473855791	2.36673839888467	1
6.18849817990230	4.97836674652459	1
6.54263406218390	0.834598177291021	1
8.54443627754918	5.47850873429622	1
1.97961580331274	0.158062518938500	1
3.67309260569741	1.13227508530645	1
5.75572976460165	1.30828876313430	1
8.45160472860189	5.90751849062474	1
3.46373193769655	1.51690595784220	1
1.91395651414783	0.612317246772413	1
4.12656938065540	2.09322445393649	1
3.65772017039046	1.80838713249693	1
4.54395760030265	1.38858754767689	1
2.86266970800483	1.34954202841685	1
8.37374823952218	6.83156860507414	1
5.62707992632729	0.893689431466388	1
6.42610816769460	1.07534713660481	1
4.69353317658611	3.15086012845143	1
4.86173611501328	3.57464916591233	1
4.35769946480643	2.34450146003832	1
6.29762879066540	0.534351386077976	1
4.71155500188011	3.43210544532588	1
1.49995426186507	0.192006207383273	1
3.16576104712266	0.493379007110222	1
4.21408348419092	2.97014277918461	1
5.52248511695330	3.63263027130760	1
4.15244831176753	1.44597290703838	1
9.55986996363196	1.13832040773527	1
1.63276516895206	0.446783742774178	1
9.38532498107474	0.913169554364942	1

Train on the training set with lr_train.py:

# coding:UTF-8
import numpy as np

def load_data(file_name):
    '''Load the training data.
    input:  file_name(string): location of the training data
    output: feature_data(mat): the features
            label_data(mat): the labels
    '''
    f = open(file_name)  # open the file
    feature_data = []
    label_data = []
    for line in f.readlines():
        feature_tmp = []
        label_tmp = []
        lines = line.strip().split("\t")
        feature_tmp.append(1)  # bias term
        for i in xrange(len(lines) - 1):
            feature_tmp.append(float(lines[i]))
        label_tmp.append(float(lines[-1]))

        feature_data.append(feature_tmp)
        label_data.append(label_tmp)
    f.close()  # close the file
    return np.mat(feature_data), np.mat(label_data)

def sig(x):
    '''Sigmoid function
    input:  x(mat): feature * w
    output: sigmoid(x)(mat): the Sigmoid value
    '''
    return 1.0 / (1 + np.exp(-x))

def lr_train_bgd(feature, label, maxCycle, alpha):
    '''Train the LR model with batch gradient descent.
    input:  feature(mat): the features
            label(mat): the labels
            maxCycle(int): maximum number of iterations
            alpha(float): learning rate
    output: w(mat): the weights
    '''
    n = np.shape(feature)[1]  # number of features (including the bias column)
    w = np.mat(np.ones((n, 1)))  # initialize the weights
    i = 0
    while i <= maxCycle:  # within the maximum number of iterations
        i += 1  # current iteration count
        h = sig(feature * w)  # compute the Sigmoid values
        err = label - h
        if i % 100 == 0:
            print "\t---------iter=" + str(i) + \
            " , train error rate= " + str(error_rate(h, label))
        w = w + alpha * feature.T * err  # update the weights
    return w

def error_rate(h, label):
    '''Compute the current value of the loss function.
    input:  h(mat): predicted values
            label(mat): true labels
    output: sum_err / m(float): the average loss
    '''
    m = np.shape(h)[0]
    
    sum_err = 0.0
    for i in xrange(m):
        if h[i, 0] > 0 and (1 - h[i, 0]) > 0:
            sum_err -= (label[i,0] * np.log(h[i,0]) + \
                        (1-label[i,0]) * np.log(1-h[i,0]))
        else:
            sum_err -= 0
    return sum_err / m

def save_model(file_name, w):
    '''Save the final model.
    input:  file_name(string): name of the file to save the model to
            w(mat): the weights of the LR model
    '''
    m = np.shape(w)[0]
    f_w = open(file_name, "w")
    w_array = []
    for i in xrange(m):
        w_array.append(str(w[i, 0]))
    f_w.write("\t".join(w_array))
    f_w.close()           

if __name__ == "__main__":
    # 1. load the training data
    print "---------- 1.load data ------------"
    feature, label = load_data("data.txt")
    # 2. train the LR model
    print "---------- 2.training ------------"
    w = lr_train_bgd(feature, label, 1000, 0.01)
    # 3. save the final model
    print "---------- 3.save model ------------"
    save_model("weights", w)

Prepare the test set test_data (two tab-separated features per line, no labels):

7.33251317753861	9.84290929650398
1.14288134664155	9.31938382343869
5.69123321602869	7.01397166818584
2.50648396980344	7.05766788137925
8.61756151890868	9.98657202473272
1.41851773509793	9.77704969645800
8.61450253418651	9.80161361673675
7.20252421999920	8.45756147719461
3.79585154363648	9.56147516348640
7.12986596603599	9.92424540796407
5.90166629240928	7.01231298989034
7.64216375281899	9.91037363924651
6.10861639371996	9.29953378509465
6.68819221312425	8.59494569110640
5.89930101159801	7.43009940132321
6.35441479217648	7.43863129967550
2.49230686464801	3.79277610650803
0.874186031122628	8.56545115532597
6.25345760678235	8.12438477163678
8.55199843954519	9.56743033736205
3.94869923690758	6.87606576238586
6.88965109334102	9.56780033528095
1.68185344098941	6.26602106875295
4.01027580639810	8.23519946958995
6.38428347772265	9.35832990092658
2.48422569298721	7.91301493123820
5.89588203576457	7.40064804417771
1.07097913402539	6.02251810104346
8.63769562664473	9.76101886404358
5.26740975881800	7.10280802002263
6.76140353375088	8.33245855777541
4.55361346498628	8.66197879154849
8.01812927282219	9.96002944206410
4.92493976967423	6.48984272359645
1.34364605003152	4.31522038864128
7.56645530385296	8.93098017284231
7.32856343461935	8.73559997192947
8.36337260868505	9.58618186062654
1.76935725388087	4.58485493022287
5.54440208531975	8.17989804462944
3.16493556356697	9.01287414297507
5.26737682037452	8.31928790306938
8.25474297446829	9.46776651141527
6.81480206199649	9.46184932462711
3.42401262277821	7.59017892397541
0.682688606067573	2.13140854275735
4.77717797708075	9.06746251589252
8.40609615806265	9.48324795436671
5.11941294784973	7.94092419194072
0.107118625511173	4.10511031080441
1.45964077373918	8.44883153836142
2.80093537840324	7.07734644006141
1.49083856549803	7.01121814413782
2.36674156086130	7.70541726069665
6.20293052826007	9.29556250878087
4.05487438652248	5.46938162981209
2.06079271845137	9.39862998789417
1.37140217072301	8.67122777257986
4.84508191734051	9.98394006421844
0.703579758778653	5.37622471649251
0.959874931625260	9.69365580472953
0.0417080172066070	7.98358222061436
7.35572898588090	9.78409851002885
0.759922609598193	5.05416257751295
2.33883362565589	8.66822288329864
3.88272444717190	9.54275911938782
1.63662325472567	4.57910351557924
1.30985082346245	3.35623833816854
7.82362986876080	9.50557703028000
4.94874181652699	6.53599112906454
7.67728005949704	9.50008478600453
3.15857142803044	7.15668195476007
3.61627230376748	5.02525628581462
2.15924538198292	4.00283995494553
1.65517009454175	4.41758058093557
3.75540362175933	5.01582106720932
8.12444498923753	9.95165814730251
4.41777683221272	7.65964160679034
3.03947468839239	9.40426842177465
3.32322103008194	4.95449449273065
7.02226861489024	8.79306734474469
2.17522157322449	5.93183247074687
0.868090726515497	2.94128556851324
8.47845531697937	9.97712220268860
5.17687735570619	6.40542188001300
2.11301922035166	5.54521551252613
7.39074636178163	8.41553439986347
0.387214214920271	2.84268913849686
5.84203927460807	9.15278983040824
5.82971366822676	8.25927093139731
4.92308003057711	7.13115624031492
6.70223526366741	8.13640943396238
6.18097890028784	7.69830072034376
3.31636136841303	7.87215118881414
7.02204691636239	8.18250988105294
8.36447373871857	9.85745951718317
4.38112469162855	7.39430116436659
4.02105374486826	6.54635130132668
4.57657789843014	7.83593612424246
7.35864937490036	9.66324641785085
6.79886317174323	2.19550400558507
8.30422412454229	3.89187751988243
4.15654393219195	2.96399968284951
8.88348530343686	4.33714944383228
6.60227577401105	3.28878632645783
2.86968063459726	0.563234429967053
5.23831013665832	0.976880305804374
8.59877913425850	1.47997081966077
3.03329602875159	0.347099994341680
3.04897868034898	0.892737314784995
3.79992057985372	2.58538966294283
4.87186652196626	0.715584122601641
9.14392871811904	7.97900095502468
4.94982975813493	0.438902015446521
3.32258226320860	0.949285465202363
6.35406466607753	1.40389865382386
6.42558780443875	3.85876366464538
2.99572060615516	0.234332825339264
3.67008285896494	0.851164479782249
4.81750083742427	1.93874942698667
1.76964217381040	0.202017397699835
8.20913160492765	0.210652826478027
9.35968725530240	6.10533760636674
5.39748076423221	2.54405282747684
3.13555221794369	0.979895632720517
9.66779685358222	4.73960088840639
5.69022247723601	1.08622919814201
5.40007969528150	2.74591412260864
7.11221986779173	2.41747595922350
4.30692983690029	3.26718716457572
1.33964979615597	0.300647133549756
9.21958144875315	6.54429819711743
1.88841050790017	0.232649111467001
4.01821155966517	2.05156276027461
2.22897823619833	0.886372899104719
1.96085675446517	0.628167164249429
5.44756542975343	3.46488351223326
7.43533370560625	5.81574338379743
9.01830253897710	2.67942045419741
7.28871249101315	1.24396912792495
1.27486851674173	0.204522588292904
5.50020192031181	2.15974654118566
9.14250014260627	4.96583927175149
6.55899750629609	4.77763763389266
8.24940482076717	4.18088773553725
2.64630222473423	0.395000602784235
8.97860739768491	0.228779804972462
5.40911249661002	0.740409676547689
9.80812584677043	6.27750259684544
5.50424461739359	2.12189727534722
1.53656980821675	0.365925533818477
1.38188023750668	0.0272836109904681
5.69484858217855	0.454132824391398
8.36333698473662	6.01987473987129
7.50195633130158	0.974418562562929
6.93644727617476	3.07861153390469
9.75677099287476	5.68306987800605
8.20297517817161	3.26869363187115
4.89152353405116	3.21172805778414
1.75122833373023	0.100041834145903
2.56049751807105	0.610057470246340
8.48241768555163	6.01110793166852
1.54424061252904	0.217292293635713
5.74188247457466	1.97641409239304
6.91173901876337	3.71241461026804
3.62785671965543	1.13431742828523
1.13938413072417	0.137162866799778
2.50451568923190	0.159804157398042
4.35168766049983	0.664031005121227
5.40718874214422	1.49621154952786
9.56467418299954	7.88234406137552
1.47409297912713	0.349813342676726
3.42207483758700	1.02413950354846
5.93083811093360	4.64848345065932
4.75969693884996	3.69597934891562
3.71309453840859	1.90214720551986
6.99704966425983	3.23316818617885
7.28294968162278	4.18776134130548
2.60319208960304	0.205231672986662
9.99172355285225	1.53867332274632
1.29340738477475	0.164660163515074
8.93679850406629	5.31110955598668
2.71389940461959	0.632285848653224
5.14653343534370	4.07039458510250
2.40764457003907	1.20427203219359
6.80288083183079	2.18346279657764
2.71831325712673	0.735872795240678
5.33819854928671	0.523237125832879
6.30556736225553	1.20005397144010
4.46157211932470	2.01804940846530
3.26625510225082	0.658212637318821
6.55381795953901	1.47332188122579
8.41938640019952	7.29075946387086
7.57223913040838	2.26004190249210
6.25662399950607	0.566501191933395
9.15677335584760	7.17513606222610
8.35984503433578	1.91891766916066
6.34920625597898	0.120424501924355
4.82733388192721	1.19687959104710
2.45336269880575	0.259812107633635

Predict on the test set with lr_test.py:

# coding:UTF-8
import numpy as np
from lr_train import sig

def load_weight(w):
    '''Load the LR model.
    input:  w(string): location of the file containing the weights
    output: np.mat(w)(mat): the weight matrix
    '''
    f = open(w)
    w = []
    for line in f.readlines():
        lines = line.strip().split("\t")
        w_tmp = []
        for x in lines:
            w_tmp.append(float(x))
        w.append(w_tmp)    
    f.close()
    return np.mat(w)

def load_data(file_name, n):
    '''Load the test data.
    input:  file_name(string): location of the test set
            n(int): number of features (including the bias)
    output: np.mat(feature_data)(mat): the features of the test set
    '''
    f = open(file_name)
    feature_data = []
    for line in f.readlines():
        feature_tmp = []
        lines = line.strip().split("\t")
        if len(lines) != n - 1:  # skip lines that do not have n-1 feature values
            continue
        feature_tmp.append(1)  # bias term
        for x in lines:
            feature_tmp.append(float(x))
        feature_data.append(feature_tmp)
    f.close()
    return np.mat(feature_data)

def predict(data, w):
    '''Predict on the test data.
    input:  data(mat): features of the test data
            w(mat): parameters of the model
    output: h(mat): the final predictions
    '''
    h = sig(data * w.T)  # apply the Sigmoid to the linear scores
    m = np.shape(h)[0]
    for i in xrange(m):
        if h[i, 0] < 0.5:
            h[i, 0] = 0.0
        else:
            h[i, 0] = 1.0
    return h

def save_result(file_name, result):
    '''Save the final predictions.
    input:  file_name(string): name of the file to save the predictions to
            result(mat): the predictions
    '''
    m = np.shape(result)[0]
    # write the predictions to a file
    tmp = []
    for i in xrange(m):
        tmp.append(str(result[i, 0]))
    f_result = open(file_name, "w")
    f_result.write("\t".join(tmp))
    f_result.close()    

if __name__ == "__main__":
    # 1. load the LR model
    print "---------- 1.load model ------------"
    w = load_weight("weights")
    n = np.shape(w)[1]
    # 2. load the test data
    print "---------- 2.load data ------------"
    testData = load_data("test_data", n)
    # 3. predict on the test data
    print "---------- 3.get prediction ------------"
    h = predict(testData, w)  # run the prediction
    # 4. save the final predictions
    print "---------- 4.save prediction ------------"
    save_result("result", h)

Running lr_train.py produces output like the following:
[Figure: the training log printed by lr_train.py (iteration count and training error every 100 iterations)]

After the training above, the final weights of the Logistic Regression model are:
$w_0=1.394177750874827$
$w_1=4.527177129107415$
$w_2=-4.793981623770908$
The final separating hyperplane is shown in Figure 1.10:
[Figure 1.10: the training data and the final separating line]
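
To reproduce a plot like Figure 1.10, the following is a minimal sketch; it assumes matplotlib is available and uses data.txt together with the trained weights listed above (it is not part of the original scripts):

# coding:UTF-8
import numpy as np
import matplotlib.pyplot as plt

# load the raw training data: two features and a label per line
data = np.loadtxt("data.txt")
x1, x2, y = data[:, 0], data[:, 1], data[:, 2]

# scatter the two classes
plt.scatter(x1[y == 0], x2[y == 0], marker="o", label="label 0")
plt.scatter(x1[y == 1], x2[y == 1], marker="x", label="label 1")

# decision boundary w0 + w1*x1 + w2*x2 = 0, with the weights trained above
w0, w1, w2 = 1.394177750874827, 4.527177129107415, -4.793981623770908
xs = np.linspace(x1.min(), x1.max(), 100)
plt.plot(xs, -(w0 + w1 * xs) / w2, "k-", label="decision boundary")

plt.legend()
plt.show()
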
Running lr_test.py, the final prediction file result contains:

0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	0.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0

References

Zhao Zhiyong. Python机器学习算法 (Python Machine Learning Algorithms). Beijing: Publishing House of Electronics Industry, July 2017.
