# 基于梯度下降算法求解线性回归

2088人阅读 评论(0)

### 三：代码实现

/**
 * Reads training samples from a comma-separated text file, one {@code x,y} pair per line.
 *
 * <p>Lines that do not split into exactly two fields are skipped. The file is decoded as UTF-8.
 * If the file does not exist, an empty list is returned. An {@link IOException} while reading
 * is printed to {@code System.err`, and the samples parsed so far are returned.
 *
 * @param fileName path of the data file
 * @return the parsed samples in file order (never {@code null})
 * @throws NumberFormatException if a two-field line contains a value that is not a number
 */
public List<DataItem> getData(String fileName) {
    List<DataItem> items = new ArrayList<DataItem>();
    File f = new File(fileName);
    if (!f.exists()) {
        return items;
    }
    // try-with-resources closes the reader even when a parse error propagates.
    try (java.io.BufferedReader br = java.nio.file.Files.newBufferedReader(
            f.toPath(), java.nio.charset.StandardCharsets.UTF_8)) {
        String line;
        while ((line = br.readLine()) != null) {
            String[] data = line.split(",");
            if (data.length == 2) {
                DataItem item = new DataItem();
                // x is stored as a float (normalization() writes fractional values into it),
                // so parse it as one; this also accepts fractional inputs.
                item.x = Float.parseFloat(data[0].trim());
                // NOTE(review): y stays integer-parsed because DataItem.y's declared type is
                // not visible here; switch to Float.parseFloat if it is a float.
                item.y = Integer.parseInt(data[1].trim());
                items.add(item);
            }
        }
    } catch (IOException ioe) {
        System.err.println(ioe);
    }
    return items;
}

/**
 * Min-max scales the {@code x} value of every item into [0, 1], in place.
 *
 * <p>If every item has the same {@code x}, the range is zero. In that case all {@code x} values
 * are set to 0 rather than NaN. An empty list is left unchanged.
 *
 * @param items samples whose {@code x} values are rewritten in place
 */
public void normalization(List<DataItem> items) {
    if (items.isEmpty()) {
        return;
    }
    // Start from the infinities so any finite x (negative, or larger than any fixed
    // sentinel) correctly updates the bounds.
    float min = Float.POSITIVE_INFINITY;
    float max = Float.NEGATIVE_INFINITY;
    for (DataItem item : items) {
        min = Math.min(min, item.x);
        max = Math.max(max, item.x);
    }
    float delta = max - min;
    for (DataItem item : items) {
        // A zero range would compute 0/0 = NaN; every x equals min then, so map it to 0.
        item.x = delta == 0 ? 0f : (item.x - min) / delta;
    }
}

/**
 * Fits {@code y ≈ theta[0] + theta[1] * x} by batch gradient descent on the mean squared error.
 * It runs 1500 iterations with a learning rate of 0.1.
 *
 * @param items training samples
 * @return {@code {theta0, theta1}}; {@code {0, 0}} when {@code items} is empty
 */
public float[] gradientDescent(List<DataItem> items) {
    return gradientDescent(items, 1500, 0.1f);
}

/**
 * Fits {@code y ≈ theta[0] + theta[1] * x} by batch gradient descent on the mean squared error.
 * Both parameters start at 0.
 *
 * @param items        training samples
 * @param iterations   number of full-batch updates; values {@code <= 0} return {@code {0, 0}}
 * @param learningRate step size applied to the averaged gradient
 * @return {@code {theta0, theta1}}; {@code {0, 0}} when {@code items} is empty
 */
public float[] gradientDescent(List<DataItem> items, int iterations, float learningRate) {
    float[] theta = new float[2]; // Java zero-initializes arrays
    int m = items.size();
    if (m == 0) {
        // Dividing by m would yield 0 * Infinity = NaN parameters.
        return theta;
    }
    float[] residuals = new float[m];
    float scale = learningRate * (1.0f / m);
    for (int iter = 0; iter < iterations; iter++) {
        // residual_k = h(x_k) - y_k, evaluated with the parameters from the previous step.
        for (int k = 0; k < m; k++) {
            DataItem item = items.get(k);
            residuals[k] = (theta[0] + theta[1] * item.x) - item.y;
        }
        // The gradient sums must restart from zero every iteration. The original kept
        // accumulating them, which mixed the previous step's scaled gradient into this one.
        float grad0 = 0f;
        float grad1 = 0f;
        for (int k = 0; k < m; k++) {
            grad0 += residuals[k];
            grad1 += residuals[k] * items.get(k).x;
        }
        // Simultaneous update: both gradients were computed from the same theta.
        theta[0] -= scale * grad0;
        theta[1] -= scale * grad1;
    }
    return theta;
}

/**
 * Evaluates the straight line {@code theta[0] + theta[1] * input}.
 *
 * <p>NOTE(review): if the model was trained on normalized x, {@code input} presumably needs the
 * same scaling; confirm at the call site.
 *
 * @param input feature value to evaluate
 * @param theta {@code {intercept, slope}}
 * @return the predicted value
 */
public float predict(float input, float[] theta) {
    final float intercept = theta[0];
    final float slope = theta[1];
    return intercept + slope * input;
}

/**
 * Renders the training points and the regression line onto a 500x500 image and passes the
 * image to {@code saveImage}.
 *
 * <p>{@code series1} is drawn as blue dots with the legend "训练数据". Consecutive points of
 * {@code series2} are joined by red segments with the legend "线性回归". Both series share one set
 * of axis bounds. X tick labels show x truncated to an int. Y tick labels show y / 10000,
 * truncated.
 *
 * @param series1 training samples
 * @param series2 points on the fitted line; presumably ordered by x (TODO confirm at caller)
 * @param theta   not read by this method
 */
public void drawPlot(List<DataItem> series1, List<DataItem> series2, float[] theta) {
    int w = 500;
    int h = 500;
    BufferedImage plot = new BufferedImage(w, h, BufferedImage.TYPE_INT_ARGB);
    Graphics2D g2d = plot.createGraphics();
    g2d.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
    g2d.setPaint(Color.WHITE);
    g2d.fillRect(0, 0, w, h);
    g2d.setPaint(Color.BLACK);
    int margin = 50;
    g2d.drawLine(margin, 0, margin, h);
    g2d.drawLine(0, h - margin, w, h - margin);

    // bounds = {minx, maxx, miny, maxy}. Infinities are used because Float.MIN_VALUE is the
    // smallest *positive* float, not the most negative one.
    float[] bounds = {Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY,
                      Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY};
    expandBounds(series1, bounds);
    expandBounds(series2, bounds);
    if (bounds[0] > bounds[1]) {
        // No points in either series: fall back to a unit range so the axes still render.
        bounds[0] = 0f;
        bounds[1] = 1f;
        bounds[2] = 0f;
        bounds[3] = 1f;
    }
    float minx = bounds[0];
    float miny = bounds[2];

    // draw X, Y titles and axes
    g2d.setPaint(Color.BLACK);
    g2d.drawString("价格(万)", 0, h/2);
    g2d.drawString("面积(平方米)", w/2, h-20);

    // draw labels and legend
    g2d.setPaint(Color.BLUE);
    float xstep = axisStep(bounds[1] - minx);
    float ystep = axisStep(bounds[3] - miny);
    int dx = (w - 2*margin) / 11;
    int dy = (h - 2*margin) / 11;

    // draw labels: tick i shows the value (i-1) steps above the minimum
    for (int i = 1; i < 11; i++) {
        g2d.drawLine(margin + i*dx, h - margin, margin + i*dx, h - margin - 10);
        g2d.drawLine(margin, h - margin - dy*i, margin + 10, h - margin - dy*i);
        int xv = (int) (minx + (i - 1)*xstep);
        float yv = (int) ((miny + (i - 1)*ystep) / 10000.0f);
        g2d.drawString("" + xv, margin + i*dx, h - margin + 15);
        g2d.drawString("" + yv, margin - 25, h - margin - dy*i);
    }

    // draw points; the -3 offset positions the 7x7 oval around the data coordinate
    g2d.setPaint(Color.BLUE);
    for (DataItem item : series1) {
        float xs = (item.x - minx) / xstep + 1;
        float ys = (item.y - miny) / ystep + 1;
        g2d.fillOval((int) (xs*dx + margin - 3), (int) (h - margin - ys*dy - 3), 7, 7);
    }
    g2d.fillRect(100, 20, 20, 10);
    g2d.drawString("训练数据", 130, 30);

    // draw regression line. Segment endpoints sit on the data coordinates themselves; the
    // oval's -3 centering offset must not be applied here, or the line is drawn shifted
    // up-left of the points.
    g2d.setPaint(Color.RED);
    for (int i = 0; i < series2.size() - 1; i++) {
        float x1 = (series2.get(i).x - minx) / xstep + 1;
        float y1 = (series2.get(i).y - miny) / ystep + 1;
        float x2 = (series2.get(i + 1).x - minx) / xstep + 1;
        float y2 = (series2.get(i + 1).y - miny) / ystep + 1;
        g2d.drawLine((int) (x1*dx + margin), (int) (h - margin - y1*dy),
                     (int) (x2*dx + margin), (int) (h - margin - y2*dy));
    }
    g2d.fillRect(100, 50, 20, 10);
    g2d.drawString("线性回归", 130, 60);

    g2d.dispose();
    saveImage(plot);
}

/** Widens {@code bounds = {minx, maxx, miny, maxy}} to include every point of {@code series}. */
private static void expandBounds(List<DataItem> series, float[] bounds) {
    for (DataItem item : series) {
        bounds[0] = Math.min(bounds[0], item.x);
        bounds[1] = Math.max(bounds[1], item.x);
        bounds[2] = Math.min(bounds[2], item.y);
        bounds[3] = Math.max(bounds[3], item.y);
    }
}

/** Returns one tenth of {@code range}, or 1 for a non-positive range, to avoid dividing by zero. */
private static float axisStep(float range) {
    return range > 0 ? range / 10.0f : 1.0f;
}

### 四：总结

个人资料
等级：
访问量： 295万+
积分： 2万+
排名： 338
博客专栏
Java数字图像处理与特效（文章：68篇，阅读：1098408）；HTML5 Canvas编程（文章：14篇，阅读：280836）
最新评论