package algorithm;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashMap;
import java.util.Map;
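/*
 * http://www.cnblogs.com/wzm-xu/p/4062266.html
 * Given a training set, fit a linear function to it, then test how well the fitted function performs.
 * For now only the training step is implemented and its result is shown.
 * Validation can be added later, with training set : test set = 4 : 1,
 * or with ten-fold cross-validation.
 * Results are only produced at the end of the term, because only then is the data volume large enough, and the interpretation of the data also has to be supplied by a person.
 * Multivariate linear regression using gradient descent.
 * Reference: Andrew Ng's Machine Learning course.
 */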
/**
 * Multivariate linear regression fitted with batch gradient descent.
 *
 * <p>The training matrix has one row per sample laid out as
 * {@code [1.0, feature_1, ..., feature_k, y]}: a leading constant term, the
 * features, and the target value in the last cell. {@code theta} holds one
 * coefficient for the constant term and one per feature.
 *
 * <p>The no-argument constructor loads the samples from the {@code data} table
 * of a local MySQL database and z-score normalizes each feature. The
 * {@link #LinearRegression(double[][])} constructor takes an already prepared
 * matrix and performs no I/O.
 */
public class LinearRegression {
    // NOTE(review): connection settings (including the password) are hard-coded;
    // they should come from configuration rather than source code.
    private static final String DB_URL = "jdbc:mysql://localhost:3306/test1";
    private static final String DB_USER = "root";
    private static final String DB_PASSWORD = "123456";

    /** Feature columns read from the {@code data} table, in theta order (after the constant). */
    private static final String[] FEATURE_COLUMNS = { "picturenum", "codelength", "browsenum", "commentnum" };
    /** Column holding the value to predict. */
    private static final String TARGET_COLUMN = "score";

    /** Default learning rate (step size). */
    private static final double DEFAULT_ALPHA = 0.001;
    /** Default number of gradient descent iterations. */
    private static final int DEFAULT_ITERATION = 100000;

    private double[][] trainData; // training samples: [1.0, features..., y]
    private int row;              // number of training samples
    private int column;           // number of parameters: constant term + features
    private double[] theta;       // parameter vector
    private double alpha;         // learning rate (step size)
    private int iteration;        // number of iterations

    /** @return the training matrix (the internal array, not a copy) */
    public double[][] getTrainData() {
        return trainData;
    }

    /** @param trainData training matrix to use; stored by reference */
    public void setTrainData(double[][] trainData) {
        this.trainData = trainData;
    }

    /** @return the number of training samples used by {@link #trainTheta()} */
    public int getRow() {
        return row;
    }

    /** @param row the number of training samples to use */
    public void setRow(int row) {
        this.row = row;
    }

    /** @return the number of parameters (constant term plus features) */
    public int getColumn() {
        return column;
    }

    /** @param column the number of parameters (constant term plus features) */
    public void setColumn(int column) {
        this.column = column;
    }

    /** @return the parameter vector (the internal array, not a copy) */
    public double[] getTheta() {
        return theta;
    }

    /** @param theta parameter vector to use; stored by reference */
    public void setTheta(double[] theta) {
        this.theta = theta;
    }

    /** @return the learning rate */
    public double getAlpha() {
        return alpha;
    }

    /** @param alpha the learning rate */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /** @return the number of gradient descent iterations */
    public int getIteration() {
        return iteration;
    }

    /** @param iteration the number of gradient descent iterations */
    public void setIteration(int iteration) {
        this.iteration = iteration;
    }

    /**
     * Creates a model whose training data is loaded from the {@code data} table.
     *
     * <p>A single connection is used for counting and loading, and every JDBC
     * resource is closed even when a query fails.
     *
     * @throws SQLException if the database cannot be reached or a query fails
     */
    public LinearRegression() throws SQLException {
        column = FEATURE_COLUMNS.length + 1;
        try (Connection connection = DriverManager.getConnection(DB_URL, DB_USER, DB_PASSWORD)) {
            row = countRows(connection);
            // One extra cell per sample for the target value y.
            trainData = new double[row][column + 1];
            loadTrainData(connection);
        }
        // The parameter vector has one entry fewer than a sample (no y).
        theta = new double[column];
        initialParameter();
    }

    /**
     * Creates a model from an already prepared training matrix, without any
     * database access.
     *
     * @param trainData samples laid out as {@code [1.0, picturenum, codelength,
     *                  browsenum, commentnum, score]}; stored by reference; may
     *                  be empty
     * @throws IllegalArgumentException if {@code trainData} is {@code null} or a
     *                                  row is {@code null} or has the wrong width
     */
    public LinearRegression(double[][] trainData) {
        if (trainData == null) {
            throw new IllegalArgumentException("trainData must not be null");
        }
        int width = FEATURE_COLUMNS.length + 2;
        for (int i = 0; i < trainData.length; i++) {
            if (trainData[i] == null || trainData[i].length != width) {
                throw new IllegalArgumentException(
                        "trainData[" + i + "] must have exactly " + width + " columns");
            }
        }
        this.trainData = trainData;
        this.row = trainData.length;
        this.column = FEATURE_COLUMNS.length + 1;
        this.theta = new double[column];
        initialParameter();
    }

    /** Resets the learning rate and iteration count to their defaults and every theta to 1.0. */
    private void initialParameter() {
        alpha = DEFAULT_ALPHA;
        iteration = DEFAULT_ITERATION;
        for (int i = 0; i < theta.length; i++) {
            theta[i] = 1.0;
        }
    }

    /** Returns the number of rows in the {@code data} table. */
    private static int countRows(Connection connection) throws SQLException {
        try (Statement statement = connection.createStatement();
                ResultSet resultSet = statement.executeQuery("select count(*) from data")) {
            return resultSet.next() ? resultSet.getInt(1) : 0;
        }
    }

    /**
     * Fills {@code trainData} from the {@code data} table, z-score normalizing
     * each feature with the mean and standard deviation reported by the database.
     */
    private void loadTrainData(Connection connection) throws SQLException {
        double[] mean = new double[FEATURE_COLUMNS.length];
        double[] std = new double[FEATURE_COLUMNS.length];
        for (int i = 0; i < FEATURE_COLUMNS.length; i++) {
            // Column names are compile-time constants, so concatenation is safe here.
            String sql = "select avg(" + FEATURE_COLUMNS[i] + "), std(" + FEATURE_COLUMNS[i] + ") from data";
            try (PreparedStatement preparedStatement = connection.prepareStatement(sql);
                    ResultSet resultSet = preparedStatement.executeQuery()) {
                if (resultSet.next()) {
                    mean[i] = resultSet.getDouble(1);
                    std[i] = resultSet.getDouble(2);
                }
            }
        }

        int loaded = 0;
        try (Statement statement = connection.createStatement();
                ResultSet resultSet = statement.executeQuery("select * from data")) {
            // The table may have grown since count(*); never write past the array.
            while (loaded < row && resultSet.next()) {
                double[] sample = trainData[loaded];
                sample[0] = 1.0; // constant term
                for (int i = 0; i < FEATURE_COLUMNS.length; i++) {
                    double centered = resultSet.getInt(FEATURE_COLUMNS[i]) - mean[i];
                    // A constant feature has std == 0; use 0.0 instead of producing NaN.
                    sample[i + 1] = std[i] > 0.0 ? centered / std[i] : 0.0;
                }
                sample[column] = resultSet.getInt(TARGET_COLUMN);
                loaded++;
            }
        }

        if (loaded < row) {
            // The table shrank since count(*): drop the unfilled rows so they are not trained on.
            double[][] trimmed = new double[loaded][];
            System.arraycopy(trainData, 0, trimmed, 0, loaded);
            trainData = trimmed;
            row = loaded;
        }
    }

    /**
     * Runs {@code iteration} steps of batch gradient descent, updating
     * {@code theta} in place. All parameters are updated simultaneously from
     * the gradient of the previous step. Does nothing when there are no samples.
     */
    public void trainTheta() {
        if (row <= 0) {
            return; // no samples: the mean gradient would be 0/0 = NaN
        }
        for (int remaining = this.iteration; remaining > 0; remaining--) {
            double[] partialDerivative = computePartialDerivative();
            for (int i = 0; i < theta.length; i++) {
                theta[i] -= alpha * partialDerivative[i];
            }
        }
    }

    /**
     * Returns the gradient of the cost: for each parameter j,
     * {@code sum_i (h(x_i) - y_i) * x_ij / m} with {@code m == row}.
     * The residuals are computed once and shared by all parameters.
     */
    private double[] computePartialDerivative() {
        double[] residual = computeResiduals();
        double[] partialDerivative = new double[theta.length];
        for (int j = 0; j < partialDerivative.length; j++) {
            double sum = 0.0;
            for (int i = 0; i < row; i++) {
                sum += residual[i] * trainData[i][j];
            }
            partialDerivative[j] = sum / row;
        }
        return partialDerivative;
    }

    /** Returns, for each sample, the prediction {@code h(x_i)} minus the true value {@code y_i}. */
    private double[] computeResiduals() {
        double[] residual = new double[row];
        for (int i = 0; i < row; i++) {
            double[] sample = trainData[i];
            double prediction = 0.0;
            // The last cell of a sample is y, so it is excluded from h(x).
            for (int k = 0; k < sample.length - 1; k++) {
                prediction += theta[k] * sample[k];
            }
            residual[i] = prediction - sample[sample.length - 1];
        }
        return residual;
    }

    /**
     * Returns the parameters keyed by name: {@code "Constant"} for the
     * intercept and the feature column name for every other coefficient.
     *
     * @return a new map from parameter name ({@code String}) to value ({@code Double})
     */
    @SuppressWarnings("rawtypes") // raw return type kept so existing callers still compile
    public Map returnTheta() {
        Map<String, Double> named = new HashMap<>();
        named.put("Constant", theta[0]);
        for (int i = 1; i < theta.length; i++) {
            // Fall back to a generated name if theta was replaced by a longer vector.
            String name = i - 1 < FEATURE_COLUMNS.length ? FEATURE_COLUMNS[i - 1] : "theta" + i;
            named.put(name, theta[i]);
        }
        return named;
    }
}
// (Web-page UI residue removed here: copy / edit / Web IDE / raw / blame / history buttons.)