线性规划的代码实现

线性规划其实实现很简单, 关键就是theta的训练。下面是我的JAVA代码实现:

我用的训练集为:

1.0 2.0 1.5
2.0 3.5 3.4
-1.2 2.0 3.5
4.7 3.2 4.5
2.3 -2.5 5.4

下面是类与函数的实现:

  1 import java.io.BufferedReader;
  2 import java.io.File;
  3 import java.io.FileReader;
  4 import java.io.IOException;
  5 
  6 /**
  7  * 线性回归
  8  * @author CassieRyu
  9  *批量梯度下降法
 10  */
 11 public class Linear {
 12     
 13     private double[][] trainData;//训练数据集
 14     private int row; //训练集的行(样本数目)
 15     private int column; //训练集的列(特征数目+2)第一行是人为添加的x0, 最后一列为y值
 16     
 17     private double[] theta; //参数theta向量
 18     private double alpha; //步长
 19     private int iteration; //迭代次数
 20 
 21     //构造函数
 22     public Linear(String fileName, double alpha, int ite){
 23         int rowF= getRowFromFile(fileName);
 24         int columnF = getColumnFromFile(fileName);
 25         
 26         row = rowF;
 27         column = columnF+1; //为了计算方便,加上x0那行
 28         
 29         trainData = new double[row][column];
 30         loadTrainData(fileName);
 31         
 32         this.alpha = alpha;
 33         this.iteration = ite;
 34         
 35         theta = new double[column-1];//减去y对应的那行
 36         initializeTheta();//theta的初始化
 37         
 38         trainedTheta();//训练后的theta值
 39     }
 40     
 41     //返回训练集的样本数目row
 42     public int getRowFromFile(String fileName){
 43         
 44         int count=0;
 45         File file = new File(fileName);
 46         BufferedReader br = null;
 47         try{
 48             br = new BufferedReader(new FileReader(file));
 49             String temp = null;
 50             while((temp = br.readLine())!=null){ //循环读取下一行
 51                 count++;
 52             }
 53         }catch(IOException e){
 54             e.printStackTrace();
 55         }finally{
 56             if(br!=null)
 57                 try{
 58                     br.close();
 59                 }catch(IOException e1){
 60                     
 61                 }
 62         }
 63         
 64         return count;
 65     }
 66     
 67     //返回训练集列数,不包含x0
 68     public int getColumnFromFile(String fileName){
 69         
 70         int count=0;
 71         File file = new File(fileName);
 72         BufferedReader br = null;
 73         try{
 74             br = new BufferedReader(new FileReader(file));
 75             String temp = null;
 76             if((temp = br.readLine())!=null){
 77                 String [] tempStr = temp.split(" ");//用空格将列分开
 78                 count = tempStr.length;//数组长度为列的数目
 79             }
 80         }catch(IOException e){
 81             e.printStackTrace();
 82         }finally{
 83             if(br!=null)
 84                 try{
 85                     br.close();
 86                 }catch(IOException e1){
 87                     
 88                 }
 89         }
 90         
 91         return count;
 92     }
 93     
 94     //返回训练集
 95     public void loadTrainData(String fileName){
 96         
 97         //初始化x0为1
 98         for(int i=0;i<row;i++)
 99             trainData[i][0]=1.0;
100         
101         File file = new File(fileName);
102         BufferedReader br = null;
103         try{
104             br = new BufferedReader(new FileReader(file));
105             String temp = null;
106             int count=0;
107             while((temp = br.readLine())!=null){ //行循环
108                 String [] tempStr = temp.split(" ");//用空格将列分开
109                 
110                 for(int i=1;i<column;i++) //对每行的每列赋值,除第一列x0==1已赋值
111                     trainData[count][i] = Double.parseDouble(tempStr[i-1]);  
112                 count++; //行号加1
113             }
114         }catch(IOException e){
115             e.printStackTrace();
116         }finally{
117             if(br!=null)
118                 try{
119                     br.close();
120                 }catch(IOException e1){
121                     
122                 }
123         }
124     }
125     
126     //初始化theta的值
127     public void initializeTheta(){
128         
129         for(int i=0;i<column-1;i++)
130             theta[i]=1.0;
131     }
132     
133     //训练theta的值
134     public void trainedTheta(){
135         
136         while((iteration--)>0){//迭代次数
137             
138             //每迭代一次需带入新的theta值重新计算一次h(xi)-y(i)
139             double[] temp = new double[row];
140             temp=getDerivation(); //h(xi)-y(i)
141             
142             for(int j=0;j<column-1;j++){//循环一次的复杂度为O(m),m为样本数目
143                 double []tep = new double[row];
144                 double result=0.0;
145                 for(int i=0;i<row;i++){
146                     tep[i] = temp[i]*trainData[i][j]; //(h(xi)-y(i))*X(ij)
147                     result+=tep[i];
148                 }
149                 theta[j]-= alpha*result;
150             }
151         }
152     }
153     
154     //得到(theta(k)*X(ik)-Y(i))即(h(xj)-yj)
155     public double[] getDerivation(){
156         
157         double [] deff = new double[row];
158         
159         for(int i=0;i<row;i++){
160             double h = getHypothesisFunc(i);
161             deff[i]=h-trainData[i][column-1];
162         }
163         return deff;
164     }
165     
166     //得到theta(k)*X(ik)
167     public double getHypothesisFunc(int i){ //i为具体的某一行
168             
169         double result=0;
170         for(int k=0;k<column-1;k++){
171             result+=theta[k]*trainData[i][k];
172         }
173         return result;
174     }
175         
176     //打印训练集
177     public void printTrainData(){
178         
179         System.out.printf("\n训练集:\n");
180         
181         for(int i=0;i<row;i++){
182             System.out.printf("第"+i+"行:");
183             
184             for(int j=0;j<column;j++){
185                 System.out.printf(trainData[i][j]+" ");
186             }
187             System.out.printf("\n");
188         }
189         System.out.printf("\n");
190             
191     }
192     
193     //打印theta值
194     public void printTheta(){
195         
196         System.out.printf("Theta集:\n");
197         for(int j=0;j<column-1;j++){
198             System.out.printf(theta[j]+" ");
199         }
200         System.out.printf("\n");
201     }
202     
203     //预测过程,即将theta带入h函数
204     public double predict(double[] newData){
205         
206         double h=0.0;
207         for(int i=0;i<column-1;i++){
208             h+=newData[i]*theta[i];
209         }
210         return h;
211     }
212 }
LinearRegression

 

根据模型进行测试:

命令行中的输入为:2.3 4.6

 1 public class LinearMain {
 2 
 3     public static void main(String[] args){
 4         
 5         String fileName = "C:\\Users\\CassieLiu\\Desktop\\train.txt";
 6         Linear lin = new Linear(fileName,0.005,100);
 7         lin.printTrainData();
 8         lin.printTheta();
 9         
10         //进行预测,数值在命令行参数里面
11         int len = args.length;
12         if(len!=(lin.getColumnFromFile(fileName)-1)){ //测试数据没有y值
13             System.out.printf("请输入对应该模型的样本!\n");
14             return;
15         }
16         else{
17             double[] arg = new double[len+1];
18             arg[0]=1.0; //给x0赋值
19             for(int i=0;i<len;i++){
20                 arg[i+1] = Double.parseDouble(args[i]);
21             }
22             
23             double result = lin.predict(arg);
24             System.out.printf("根据模型预测出的值为:"+result);
25         }
26             
27         
28     }
29 }
main

 

最后的输出结果为:

 

有不足之处请指出!

本文为博主原创博文,未经许可请勿转载!

转载于:https://www.cnblogs.com/CassieRyu633/p/4795627.html

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值