Logistic回归与梯度下降法

转自:http://blog.csdn.net/acdreamers/article/details/44657979

Logistic回归为概率型非线性回归模型,是研究二分类观察结果与一些影响因素之间关系的一种

分析方法。通常的问题是,研究某些因素条件下某个结果是否发生,比如医学中根据病人的一些症状来判断它是

否患某种病。

 

在讲解Logistic回归理论之前,我们先从LR分类器说起。LR分类器,即Logistic Regression Classifier。

在分类情形下,经过学习后的LR分类器是一组权值,当测试样本的数据输入时,这组权值与测试数

据按线性加和得到

 

           

 

这里是每个样本的个特征。之后按照Sigmoid函数(又称为Logistic函数)的形式求出

 

           

 

由于Sigmoid函数的定义域为,值域为,因此最基本的LR分类器适合对两类目标进行分类。

所以Logistic回归最关键的问题就是研究如何求得这组权值。此问题用极大似然估计来做。

 

 

下面正式地来讲Logistic回归模型

 

考虑具有个独立变量的向量,设条件慨率为根据观测量相对于某事件发生

的概率。那么Logistic回归模型可以表示为

 

           

 

其中那么在条件下不发生的概率为

 

           

 

所以事件发生与不发生的概率之比为

 

           

 

这个比值称为事件的发生比(the odds of experiencing an event),简记为odds

 

可以看出Logistic回归都是围绕一个Logistic函数来展开的。接下来就讲如何用极大似然估计求分类器的参数。

 

假设有个观测样本,观测值分别为,设为给定条件下得到的概率,

同样地,的概率为,所以得到一个观测值的概率为

 

因为各个观测样本之间相互独立,那么它们的联合分布为各边缘分布的乘积。得到似然函数为

 

                                         

 

然后我们的目标是求出使这一似然函数的值最大的参数估计,最大似然估计就是求出参数,使

取得最大值,对函数取对数得到

 

            

 

现在求向量,使得最大,其中

 

这里介绍一种方法,叫做梯度下降法(求局部极小值),当然相对还有梯度上升(求局部极大值)。

对上述的似然函数求偏导后得到

 

            

 

由于是求局部极大值,所以根据梯度上升法,有

 

                    

 

根据上述公式,只需初始化向量全为零,或者随机值,迭代到指定精度为止。

 

现在就来用C++编程实现Logistic回归的梯度上升算法。首先要对训练数据进行处理,假设训练数据如下

 

训练数据:TrainData.txt

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. 1 0 0 1 0 1  
  2. 0 0 1 2 0 0  
  3. 1 0 0 1 1 0  
  4. 0 0 0 0 1 0  
  5. 0 0 1 0 0 0  
  6. 0 0 1 0 1 0  
  7. 0 0 1 2 1 0  
  8. 1 0 0 0 0 0  
  9. 0 0 1 0 1 0  
  10. 1 0 1 0 0 0  
  11. 0 0 1 0 1 0  
  12. 0 0 1 0 0 0  
  13. 0 0 1 0 1 0  
  14. 1 0 0 1 0 0  
  15. 1 0 0 0 1 0  
  16. 2 0 0 0 1 0  
  17. 1 0 0 2 1 0  
  18. 2 0 0 0 1 0  
  19. 2 0 1 0 0 0  
  20. 0 0 1 0 1 0  
  21. 0 0 1 2 0 0  
  22. 0 0 0 0 0 0  
  23. 0 0 1 0 1 0  
  24. 1 0 1 0 1 1  
  25. 0 0 1 2 1 0  
  26. 1 0 1 0 0 0  
  27. 0 0 1 0 0 0  
  28. 0 0 0 2 0 0  
  29. 1 0 0 0 1 0  
  30. 2 0 1 0 0 0  
  31. 2 0 1 1 1 0  
  32. 1 0 1 1 0 0  
  33. 1 0 1 2 0 0  
  34. 1 0 0 1 1 0  
  35. 0 0 0 0 1 0  
  36. 1 1 0 0 1 0  
  37. 1 0 1 2 1 0  
  38. 0 0 0 0 1 0  
  39. 0 0 1 0 0 0  
  40. 1 0 1 1 1 0  
  41. 1 0 1 0 1 0  
  42. 2 0 1 2 0 0  
  43. 0 0 1 2 1 0  
  44. 0 0 1 0 1 0  
  45. 2 0 1 0 1 0  
  46. 0 0 1 0 1 0  
  47. 1 0 0 0 0 0  
  48. 1 0 0 0 1 0  
  49. 0 0 0 0 1 0  
  50. 0 0 1 2 1 0  
  51. 0 1 1 0 0 0  
  52. 0 1 0 0 1 0  
  53. 2 1 0 0 0 0  
  54. 2 1 0 0 0 0  
  55. 1 1 0 2 0 0  
  56. 1 1 0 0 0 1  
  57. 0 1 0 0 0 0  
  58. 2 1 0 0 1 0  
  59. 0 1 0 0 1 0  
  60. 2 1 0 2 1 0  
  61. 2 1 0 2 1 0  
  62. 1 1 0 2 1 0  
  63. 0 1 0 0 0 1  
  64. 2 1 1 0 1 0  
  65. 2 1 0 1 1 0  
  66. 1 1 0 0 0 1  
  67. 2 1 0 0 0 0  
  68. 1 1 0 0 1 0  
  69. 1 1 0 0 0 0  
  70. 2 1 0 1 1 0  
  71. 1 1 0 0 1 0  
  72. 1 0 1 1 0 1  
  73. 2 1 0 1 1 0  
  74. 0 1 0 0 1 0  
  75. 1 0 1 0 0 0  
  76. 0 0 1 0 0 1  
  77. 1 0 0 0 0 0  
  78. 0 0 0 2 1 0  
  79. 1 0 1 2 0 1  
  80. 1 0 0 1 1 0  
  81. 2 0 1 2 1 0  
  82. 2 0 0 0 1 0  
  83. 1 0 0 1 1 0  
  84. 1 0 1 0 1 0  
  85. 0 0 1 0 0 0  
  86. 1 0 0 2 1 0  
  87. 2 0 1 1 1 0  
  88. 0 0 1 0 1 0  
  89. 0 0 0 0 1 0  
  90. 2 0 0 1 0 1  
  91. 0 0 1 0 0 0  
  92. 0 0 0 0 0 0  
  93. 1 0 1 1 1 1  
  94. 2 0 1 0 1 0  
  95. 0 0 0 0 0 0  
  96. 1 0 1 0 1 0  
  97. 0 0 0 0 1 0  
  98. 0 0 0 2 0 0  
  99. 0 0 0 0 0 0  
  100. 0 0 1 2 0 0  
  101. 0 0 1 0 1 0  
  102. 0 0 1 0 0 1  
  103. 0 0 0 2 1 0  
  104. 1 0 1 1 1 0  
  105. 1 0 0 1 1 0  
  106. 0 0 1 0 1 0  
  107. 1 0 0 0 0 0  
  108. 1 0 1 0 1 0  
  109. 2 0 0 0 1 0  
  110. 1 0 0 0 1 0  
  111. 2 0 0 1 1 0  
  112. 0 0 1 2 1 0  
  113. 1 0 1 2 0 0  
  114. 0 0 1 2 1 0  
  115. 1 0 0 0 0 0  
  116. 0 0 1 0 1 0  
  117. 0 0 0 1 1 0  
  118. 1 0 0 0 1 0  
  119. 2 0 0 1 1 0  
  120. 1 0 0 1 1 0  
  121. 1 0 1 0 0 0  
  122. 1 1 0 1 1 0  
  123. 2 1 0 0 1 0  
  124. 0 1 0 0 0 0  
  125. 1 1 0 1 0 1  
  126. 1 1 0 2 1 0  
  127. 0 1 0 0 0 0  
  128. 1 1 0 2 0 0  
  129. 0 1 0 0 1 0  
  130. 1 1 0 0 1 1  
  131. 1 1 0 2 1 0  
  132. 1 0 0 2 1 0  
  133. 2 1 1 1 1 0  
  134. 0 1 0 0 1 0  
  135. 0 1 0 0 1 0  
  136. 2 1 0 0 0 1  
  137. 1 1 0 2 1 0  
  138. 1 1 0 0 1 0  
  139. 1 1 1 0 0 0  
  140. 2 1 0 2 1 0  
  141. 2 1 1 1 0 0  
  142. 0 1 0 0 1 0  
  143. 1 1 0 2 1 0  
  144. 0 1 0 0 1 0  
  145. 1 1 0 1 1 0  
  146. 0 1 0 0 1 0  
  147. 0 1 0 0 0 0  
  148. 1 1 0 0 0 0  
  149. 1 1 0 2 1 0  
  150. 1 1 0 0 0 0  
  151. 0 1 1 2 0 0  
  152. 2 1 0 0 1 0  
  153. 2 0 1 0 0 1  
  154. 0 0 1 0 1 0  
  155. 1 0 1 0 0 0  
  156. 0 0 1 2 1 0  
  157. 0 0 1 0 0 0  
  158. 1 0 1 0 1 0  
  159. 0 0 1 0 1 0  
  160. 0 0 1 0 1 0  
  161. 1 0 1 0 1 0  
  162. 0 0 0 0 0 1  
  163. 0 0 1 2 1 0  
  164. 0 0 1 0 1 0  
  165. 0 0 1 0 1 0  
  166. 0 0 1 0 0 0  
  167. 0 0 1 0 0 1  
  168. 0 0 1 2 1 0  
  169. 2 0 1 2 1 0  
  170. 0 0 1 0 1 0  
  171. 0 0 1 0 1 0  
  172. 0 0 1 0 1 0  
  173. 1 0 0 0 0 0  
  174. 2 0 1 1 1 0  
  175. 0 0 1 0 0 1  
  176. 1 0 1 0 0 0  
  177. 1 0 1 1 1 0  
  178. 1 0 1 1 0 0  
  179. 0 0 1 0 0 0  
  180. 1 0 1 1 1 0  
  181. 1 0 1 2 0 0  
  182. 2 0 0 0 1 0  
  183. 0 0 1 0 0 1  
  184. 0 0 1 0 1 0  
  185. 0 0 1 0 1 0  
  186. 1 0 1 0 0 0  
  187. 0 0 1 0 0 0  
  188. 2 0 1 1 0 0  
  189. 0 0 1 2 0 0  
  190. 1 0 0 1 1 1  
  191. 0 0 0 0 1 0  
  192. 0 0 0 0 0 1  
  193. 0 0 1 0 1 0  
  194. 2 0 1 2 1 0  
  195. 1 0 0 1 0 0  
  196. 0 0 1 0 0 0  
  197. 2 0 0 1 1 1  
  198. 0 0 1 0 0 0  
  199. 0 0 1 0 1 0  
  200. 2 0 1 0 1 0  
  201. 0 0 1 0 1 0  
  202. 2 0 0 0 1 0  
  203. 1 0 1 0 1 0  
  204. 1 0 0 0 1 0  
  205. 0 0 1 0 0 1  
  206. 2 0 0 0 0 0  
  207. 2 0 0 1 1 0  
  208. 0 0 1 0 1 0  
  209. 0 0 0 0 1 0  
  210. 2 0 1 0 0 0  
  211. 1 0 1 0 1 0  
  212. 0 0 0 0 1 0  
  213. 1 0 1 0 1 0  
  214. 0 0 1 0 0 0  
  215. 1 0 1 0 1 0  
  216. 1 0 1 0 1 0  
  217. 1 0 1 0 1 0  
  218. 0 0 1 2 0 0  
  219. 2 0 1 0 1 1  
  220. 0 0 1 0 1 0  
  221. 0 0 1 2 1 0  
  222. 0 0 0 0 0 0  
  223. 0 0 1 0 1 0  
  224. 1 0 1 0 1 0  
  225. 0 0 1 0 1 0  
  226. 1 0 1 0 0 0  
  227. 0 0 1 0 1 0  
  228. 0 0 1 0 0 0  
  229. 1 0 1 0 0 0  
  230. 0 0 1 0 1 0  
  231. 0 0 1 0 1 0  
  232. 1 0 0 0 1 0  
  233. 0 0 1 0 0 0  
  234. 0 0 0 0 1 0  
  235. 1 0 1 1 1 0  
  236. 0 0 0 2 0 0  
  237. 0 0 1 0 1 0  
  238. 0 0 1 0 1 0  
  239. 0 0 1 0 1 0  
  240. 1 0 0 1 1 0  
  241. 2 0 0 0 1 0  
  242. 1 0 0 0 0 0  
  243. 2 0 0 2 1 0  
  244. 0 0 1 2 1 0  
  245. 1 0 1 0 0 1  
  246. 0 0 1 2 1 0  
  247. 0 0 1 2 1 0  
  248. 0 0 1 0 1 0  
  249. 1 0 1 2 1 0  
  250. 0 0 0 2 0 0  
  251. 1 0 0 0 0 0  
  252. 0 0 0 2 1 0  
  253. 0 0 1 0 1 0  
  254. 2 0 0 0 1 0  
  255. 1 0 0 0 0 0  
  256. 1 0 0 1 1 0  
  257. 1 0 1 1 1 0  
  258. 1 0 1 0 1 1  
  259. 0 0 1 0 1 0  
  260. 1 1 0 2 1 0  
  261. 1 1 0 1 0 0  
  262. 2 1 0 2 1 0  
  263. 1 1 1 0 0 0  
  264. 0 1 1 0 0 0  
  265. 0 1 1 0 0 1  
  266. 0 1 0 0 1 0  
  267. 1 1 1 0 0 0  
  268. 1 1 1 0 1 0  
  269. 0 1 0 0 1 0  
  270. 0 1 1 0 0 1  
  271. 1 1 1 1 1 0  
  272. 1 1 0 2 1 0  
  273. 0 1 0 2 0 0  
  274. 1 1 0 2 1 0  
  275. 0 0 1 2 1 0  
  276. 2 1 1 1 1 0  
  277. 0 1 0 0 1 0  
  278. 0 0 1 0 1 0  
  279. 2 1 0 1 1 0  
  280. 0 1 0 0 1 0  
  281. 1 1 0 0 0 0  
  282. 1 1 0 0 1 0  
  283. 0 1 0 0 0 0  
  284. 0 1 1 0 0 0  
  285. 2 1 0 0 1 0  
  286. 2 1 0 0 0 0  
  287. 1 1 0 0 1 0  
  288. 2 1 0 1 1 0  

上面训练数据中,每一行代表一组训练数据,每组有7个数组,第1个数字代表ID,可以忽略之,2~6代表这组训

练数据的特征输入,第7个数字代表输出,为0或者1。每个数据之间用一个空格隔开。

 

首先我们来研究如何一行一行读取文本,在C++中,读取文本的一行用getline()函数。

getline()函数表示读取文本的一行,返回的是读取的字节数,如果读取失败则返回-1。用法如下:

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #include <iostream>  
  2. #include <string.h>  
  3. #include <fstream>  
  4. #include <string>  
  5. #include <stdio.h>  
  6.    
  7. using namespace std;  
  8.    
  9. int main()  
  10. {  
  11.     string filename = "data.in";  
  12.     ifstream file(filename.c_str());  
  13.     char s[1024];  
  14.     if(file.is_open())  
  15.     {  
  16.         while(file.getline(s,1024))  
  17.         {  
  18.             int x,y,z;  
  19.             sscanf(s,"%d %d %d",&x,&y,&z);  
  20.             cout<<x<<" "<<y<<" "<<z<<endl;  
  21.         }  
  22.     }  
  23.     return 0;  
  24. }  


拿到每一行后,可以把它们提取出来,进行系统输入。 Logistic回归的梯度上升算法实现如下

 

代码:

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #include <iostream>  
  2. #include <string.h>  
  3. #include <fstream>  
  4. #include <stdio.h>  
  5. #include <math.h>  
  6. #include <vector>  
  7.    
  8. #define Type double  
  9. #define Vector vector  
  10. using namespace std;  
  11.    
  12. struct Data  
  13. {  
  14.     Vector<Type> x;  
  15.     Type y;  
  16. };  
  17.    
  18. void PreProcessData(Vector<Data>& data, string path)  
  19. {  
  20.     string filename = path;  
  21.     ifstream file(filename.c_str());  
  22.     char s[1024];  
  23.     if(file.is_open())  
  24.     {  
  25.         while(file.getline(s, 1024))  
  26.         {  
  27.             Data tmp;  
  28.             Type x1, x2, x3, x4, x5, x6, x7;  
  29.             sscanf(s,"%lf %lf %lf %lf %lf %lf %lf", &x1, &x2, &x3, &x4, &x5, &x6, &x7);  
  30.             tmp.x.push_back(1);  
  31.             tmp.x.push_back(x2);  
  32.             tmp.x.push_back(x3);  
  33.             tmp.x.push_back(x4);  
  34.             tmp.x.push_back(x5);  
  35.             tmp.x.push_back(x6);  
  36.             tmp.y = x7;  
  37.             data.push_back(tmp);  
  38.         }  
  39.     }  
  40. }  
  41.    
  42. void Init(Vector<Data> &data, Vector<Type> &w)  
  43. {  
  44.     w.clear();  
  45.     data.clear();  
  46.     PreProcessData(data, "TrainData.txt");  
  47.     for(int i = 0; i < data[0].x.size(); i++)  
  48.         w.push_back(0);  
  49. }  
  50.    
  51. Type WX(const Data& data, const Vector<Type>& w)  
  52. {  
  53.     Type ans = 0;  
  54.     for(int i = 0; i < w.size(); i++)  
  55.         ans += w[i] * data.x[i];  
  56.     return ans;  
  57. }  
  58.    
  59. Type Sigmoid(const Data& data, const Vector<Type>& w)  
  60. {  
  61.     Type x = WX(data, w);  
  62.     Type ans = exp(x) / (1 + exp(x));  
  63.     return ans;  
  64. }  
  65.    
  66. Type Lw(const Vector<Data>& data, Vector<Type> w)  
  67. {  
  68.     Type ans = 0;  
  69.     for(int i = 0; i < data.size(); i++)  
  70.     {  
  71.         Type x = WX(data[i], w);  
  72.         ans += data[i].y * x - log(1 + exp(x));  
  73.     }  
  74.     return ans;  
  75. }  
  76.    
  77. void Gradient(const Vector<Data>& data, Vector<Type> &w, Type alpha)  
  78. {  
  79.     for(int i = 0; i < w.size(); i++)  
  80.     {  
  81.         Type tmp = 0;  
  82.         for(int j = 0; j < data.size(); j++)  
  83.             tmp += alpha * data[j].x[i] * (data[j].y - Sigmoid(data[j], w));  
  84.         w[i] += tmp;  
  85.     }  
  86. }  
  87.    
  88. void Display(int cnt, Type objLw, Type newLw, Vector<Type> w)  
  89. {  
  90.     cout<<"第"<<cnt<<"次迭代:  ojLw = "<<objLw<<"  两次迭代的目标差为: "<<(newLw - objLw)<<endl;  
  91.     cout<<"参数w为: ";  
  92.     for(int i = 0; i < w.size(); i++)  
  93.         cout<<w[i]<<" ";  
  94.     cout<<endl;  
  95.     cout<<endl;  
  96. }  
  97.    
  98. void Logistic(const Vector<Data>& data, Vector<Type> &w)  
  99. {  
  100.     int cnt = 0;  
  101.     Type alpha = 0.1;  
  102.     Type delta = 0.00001;  
  103.     Type objLw = Lw(data, w);  
  104.     Gradient(data, w, alpha);  
  105.     Type newLw = Lw(data, w);  
  106.     while(fabs(newLw - objLw) > delta)  
  107.     {  
  108.         objLw = newLw;  
  109.         Gradient(data, w, alpha);  
  110.         newLw = Lw(data, w);  
  111.         cnt++;  
  112.         Display(cnt,objLw,newLw, w);  
  113.     }  
  114. }  
  115.    
  116. void Separator(Vector<Type> w)  
  117. {  
  118.     Vector<Data> data;  
  119.     PreProcessData(data, "TestData.txt");  
  120.     cout<<"预测分类结果:"<<endl;  
  121.     for(int i = 0; i < data.size(); i++)  
  122.     {  
  123.         Type p0 = 0;  
  124.         Type p1 = 0;  
  125.         Type x = WX(data[i], w);  
  126.         p1 = exp(x) / (1 + exp(x));  
  127.         p0 = 1 - p1;  
  128.         cout<<"实例: ";  
  129.         for(int j = 0; j < data[i].x.size(); j++)  
  130.             cout<<data[i].x[j]<<" ";  
  131.         cout<<"所属类别为:";  
  132.         if(p1 >= p0) cout<<1<<endl;  
  133.         else cout<<0<<endl;  
  134.     }  
  135. }  
  136.    
  137. int main()  
  138. {  
  139.     Vector<Type> w;  
  140.     Vector<Data> data;  
  141.    
  142.     Init(data, w);  
  143.     Logistic(data, w);  
  144.     Separator(w);  
  145.     return 0;  
  146. }  

 

测试数据:TestData.txt

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. 10009 1 0 0 1 0 1  
  2. 10025 0 0 1 2 0 0  
  3. 20035 0 0 1 0 0 1  
  4. 20053 1 0 0 0 0 0  
  5. 30627 1 0 1 2 0 0  
  6. 30648 2 0 0 0 1 0  
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值