How to compute partial derivatives in Java: LinearRegression.java

package algorithm;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashMap;
import java.util.Map;

/*
 * http://www.cnblogs.com/wzm-xu/p/4062266.html
 * Given a training set, fit a linear function to it, then test how well the trained function performs.
 * For now only the training step is implemented and its result is displayed.
 * Validation can be added later, e.g. a training:test split of 4:1, or 10-fold cross-validation.
 * Results are only produced at the end of the term, because only then is the data set large enough,
 * and interpreting the data still requires human input.
 * Multivariate linear regression trained with gradient descent.
 * Reference: Andrew Ng's machine learning course.
 */
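
/*
 * For reference: the hypothesis and the cost function that the gradient descent code below
 * minimises, in the notation of Ng's course, where m is the number of training rows (the
 * field `row`) and x_0 = 1 is the constant column added in loadTrainData():
 *
 *   h_theta(x) = theta_0 * x_0 + theta_1 * x_1 + ... + theta_4 * x_4
 *   J(theta)   = (1 / (2m)) * sum_{i=1..m} (h_theta(x^(i)) - y^(i))^2
 */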

public class LinearRegression {

    private double[][] trainData; // training data
    private int row;              // number of training rows
    private int column;           // number of training-data columns; a constant column is added here, i.e. one more than the original feature count
    private double[] theta;       // parameter vector
    private double alpha;         // learning rate (step size)
    private int iteration;        // number of iterations
    private String[] queries = { "picturenum", "codelength", "browsenum", "commentnum" };

    public double[][] getTrainData() {
        return trainData;
    }

    public void setTrainData(double[][] trainData) {
        this.trainData = trainData;
    }

    public int getRow() {
        return row;
    }

    public void setRow(int row) {
        this.row = row;
    }

    public int getColumn() {
        return column;
    }

    public void setColumn(int column) {
        this.column = column;
    }

    public double[] getTheta() {
        return theta;
    }

    public void setTheta(double[] theta) {
        this.theta = theta;
    }

    public double getAlpha() {
        return alpha;
    }

    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    public int getIteration() {
        return iteration;
    }

    public void setIteration(int iteration) {
        this.iteration = iteration;
    }

    public LinearRegression() throws SQLException {
        String url = "jdbc:mysql://localhost:3306/test1";
        String username = "root";
        String password = "123456";
        Connection connection = DriverManager.getConnection(url, username, password);
        try {
            Statement statement = connection.createStatement();
            ResultSet resultSet = statement.executeQuery("select count(*) from data");
            while (resultSet.next()) {
                row = resultSet.getInt(1);
            }
            column = 5;
        } finally {
            connection.close();
        }
        trainData = new double[row][column + 1];
        loadTrainData();
        // allocate the parameter vector; it is one column shorter than trainData (no y column)
        theta = new double[column];
        initialParameter();
    }

    private void initialParameter() {
        // initialise the learning rate (step size) and the number of iterations
        alpha = 0.001;
        iteration = 100000;
        // initialise the parameter vector; every component starts at 1.0
        for (int i = 0; i < theta.length; i++) {
            theta[i] = 1.0;
        }
    }

    // load the training data
    private void loadTrainData() throws SQLException {
        String url = "jdbc:mysql://localhost:3306/test1";
        String username = "root";
        String password = "123456";
        Connection connection = DriverManager.getConnection(url, username, password);
        // constant term
        for (int i = 0; i < row; i++) {
            trainData[i][0] = 1.0;
        }
        try {
            Statement statement = connection.createStatement();
            int cnt = 0;
            String[][] sqls = {
                    { "select avg(picturenum) from data", "select std(picturenum) from data" },
                    { "select avg(codelength) from data", "select std(codelength) from data" },
                    { "select avg(browsenum) from data", "select std(browsenum) from data" },
                    { "select avg(commentnum) from data", "select std(commentnum) from data" }
            };
            ResultSet resultSet = null;
            PreparedStatement preparedStatement = null;
            // tmp[i][0] holds the mean and tmp[i][1] the standard deviation of feature i,
            // used below to standardise each feature as (x - mean) / std
            double[][] tmp = new double[queries.length][2];
            for (int i = 0; i < queries.length; i++) {
                for (int j = 0; j < 2; j++) {
                    preparedStatement = connection.prepareStatement(sqls[i][j]);
                    resultSet = preparedStatement.executeQuery();
                    while (resultSet.next()) {
                        tmp[i][j] = resultSet.getDouble(1);
                    }
                }
            }
            resultSet = statement.executeQuery("select * from data");
            while (resultSet.next()) {
                // column 0 is reserved for the constant term
                for (int i = 0; i < queries.length; i++) {
                    trainData[cnt][i + 1] = resultSet.getInt(queries[i]);
                    trainData[cnt][i + 1] = (trainData[cnt][i + 1] - tmp[i][0]) / tmp[i][1];
                }
                // the score, i.e. the target value y
                trainData[cnt][5] = resultSet.getInt("score");
                cnt++;
            }
            resultSet.close();
            statement.close();
        } finally {
            connection.close();
        }
    }
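
    /*
     * Batch gradient descent as implemented in trainTheta() below: every iteration computes
     * the partial derivative of J(theta) with respect to each theta_j over all m rows and
     * then updates all parameters simultaneously:
     *
     *   dJ/dtheta_j = (1/m) * sum_{i=1..m} (h_theta(x^(i)) - y^(i)) * x_j^(i)
     *   theta_j    := theta_j - alpha * dJ/dtheta_j
     *
     * computeGeneralTerm() produces one term of the sum; computePartialDerivativeForEach()
     * returns the full derivative for a single j.
     */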

    // train on the samples to obtain the parameter values
    public void trainTheta() {
        int iteration = this.iteration;
        while ((iteration--) > 0) {
            // compute the partial derivative for each theta:
            // partialDerivative := sum(...) / m
            double[] partialDerivative = computePartialDerivative();
            // update every theta; all components are updated simultaneously
            for (int i = 0; i < theta.length; i++) {
                theta[i] -= alpha * partialDerivative[i];
            }
        }
    }

    // return the vector of partial derivatives
    private double[] computePartialDerivative() {
        double[] partialDerivative = new double[theta.length];
        for (int i = 0; i < partialDerivative.length; i++) {
            partialDerivative[i] = computePartialDerivativeForEach(i);
        }
        return partialDerivative;
    }

    // compute the partial derivative for a single theta_c:
    // partialDerivative == sum(...) / m, where row == m
    private double computePartialDerivativeForEach(int c) {
        double sum = 0.0;
        // accumulate the sum over all samples
        for (int i = 0; i < row; i++) {
            sum += computeGeneralTerm(i, c);
        }
        return sum / row;
    }

    // compute one term of the sum (sigma)
    private double computeGeneralTerm(int i, int j) {
        double res = 0.0;
        // pick the i-th sample
        double[] r = trainData[i];
        // compute h(x(i))
        for (int k = 0; k < r.length - 1; k++) {
            res += theta[k] * r[k];
        }
        // r[r.length - 1] == y
        // compute the difference between the prediction and the true value
        res -= r[r.length - 1];
        // multiply by the j-th variable of this sample
        res *= r[j];
        return res;
    }

    // return the parameter vector
    public Map<String, Double> returnTheta() {
        Map<String, Double> map = new HashMap<>();
        map.put("Constant", theta[0]);
        for (int i = 1; i < theta.length; i++) {
            map.put(queries[i - 1], theta[i]);
        }
        return map;
    }

}
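
A minimal way to exercise the class could look like the sketch below. LinearRegression, trainTheta() and returnTheta() come from the listing above; the Main class is only a hypothetical driver, and it assumes the MySQL database test1 with a data table and the credentials hard-coded in the constructor are actually available.

package algorithm;

import java.sql.SQLException;
import java.util.Map;

public class Main {

    public static void main(String[] args) {
        try {
            // Connecting loads the training data from MySQL and standardises each feature.
            LinearRegression lr = new LinearRegression();

            // Run batch gradient descent with the defaults (alpha = 0.001, 100000 iterations).
            lr.trainTheta();

            // Print the learned parameters: the constant term plus one weight per feature.
            for (Map.Entry<String, Double> entry : lr.returnTheta().entrySet()) {
                System.out.println(entry.getKey() + " = " + entry.getValue());
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }
}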
