梯度下降是做线性回归时比较常用的方法,关于线性回归和梯度下降的详细介绍可详见:https://blog.csdn.net/xiazdong/article/details/7950084 ,这里用到的数学知识比较多了,推导过程真心看不懂了,不过幸好最终的公式(文章最后的公式)还能看个大概,依葫芦画瓢还能写成Code。其实里面有个重要的概念 Cost Function,而选用的是最小二乘法,就是为了对比线性公式计算后的值与实际值之间的偏离。
在线性回归中,常用到多项式也就是多变量的情况,比如,一个公式 y = a1 * x1 + a2 * x2 + ...... + b 。其中 y 和 各个x都是基于时间的已知数据,每一组x都有对应的y。各个a和b是未知的,而线性回归的流程就是先给出任意a和b的值,然后把每一组x代入公式,这样就能够求出yy ,最后需要对比 yy 和 y的偏差,追求yy和y的偏差尽量小。每变换一组a和b的值就可以算出新的yy,然后再和y对比偏差,如此穷举循环,找出与实际y偏差最小的那个yy,此时对应的这组a和b的值就是最终结果,成为最终的完整公式,用于根据新的x来计算 y,也就相当于做预测。上述的流程就是穷举求极限的过程(极限小), Cost Function 选用最小二乘,那就是穷举得出 yy和y的最小二乘的极限小值,这个求极限就可以转变成 偏导数的问题了,这样最小二乘中的 平方 就变成一阶了 。最后就是上面提到blog里面的梯度下降公式了:
而本篇的内容将上面的数学公式变成Java Code,首先创建用于多项式的接口
公式接口
/**
 * A linear-style formula y = a1 * x1 + a2 * x2 + ... + b whose coefficients can be
 * reassigned between evaluations, e.g. by an optimizer such as gradient descent.
 */
public interface LinearFormula {
    /**
     * Sets the coefficient values (the "a" parameters; may include the bias b).
     *
     * @param parameters coefficient values to use for subsequent calculations
     */
    void setParameters(double[] parameters);

    /**
     * Evaluates the formula for every sample.
     *
     * @param x one array per variable; x[j][i] is the value of variable j for sample i
     * @return the computed y value for each sample
     */
    double[] calculateResult(double[]... x);
}
然后是 梯度下降的实现:
/**
 * Batch gradient descent over a {@link LinearFormula}, using least squares as the
 * cost function and MAPE as the optional early-stop criterion.
 */
public class GradientDescent {
    private static final Logger logger = LoggerFactory.getLogger(GradientDescent.class);

    /** Step size applied to every parameter update. */
    private double learningRate;
    /** The formula whose parameters are being fitted. */
    private LinearFormula linearFormula;
    /** Whether the formula has a bias term b (a0), i.e. y = a0 + a1 * x1 + ... */
    private boolean needBias = true;

    public GradientDescent(double learningRate, LinearFormula linearFormula, boolean needBias){
        this.learningRate = learningRate;
        this.linearFormula = linearFormula;
        this.needBias = needBias;
    }

    /**
     * Sum of (p[i] - y[i]), optionally weighted by x[i]: the inner term of the gradient
     * for one parameter. Pass {@code x == null} (or empty) for the bias parameter.
     */
    private double sum(double[] y, double[] p, double[] x){
        boolean isXNotEmpty = ArrayUtils.isNotEmpty(x);
        double sum = 0;
        for (int i = 0; i < y.length; i++){
            double v = p[i] - y[i];
            if (isXNotEmpty) {
                v = v * x[i];
            }
            sum = sum + v;
        }
        return sum;
    }

    /** Returns an array of the given length with every element set to pointValue. */
    private double[] generateArray(int length, double pointValue){
        double[] ret = new double[length];
        Arrays.fill(ret, pointValue);
        return ret;
    }

    /**
     * Trains the formula parameters against the observed data.
     *
     * @param y observed target values, one per sample; must not be empty
     * @param expectMape expected MAPE threshold: because the error can shrink and then
     *        grow again over many iterations, training stops as soon as the MAPE drops
     *        to this value. Pass a negative value to disable MAPE checking entirely.
     * @param trainCount maximum number of iterations; must be positive
     * @param x one array per variable, each the same length as y
     * @return the fitted parameters: a0 (bias) followed by a1..an
     * @throws IllegalArgumentException if y is empty or trainCount is not positive
     */
    public double[] train(double[] y, double expectMape, int trainCount, double[]... x){
        int length = y.length;
        if (length == 0) {
            // Fail fast: with zero samples, (learningRate / length) is infinite and
            // every subsequent value degenerates to NaN.
            throw new IllegalArgumentException("y must not be empty");
        }
        if (trainCount <= 0) {
            throw new IllegalArgumentException("trainCount must be positive, got: " + trainCount);
        }
        int xlength = x.length;
        int alength = xlength + 1; // one coefficient per variable plus the bias slot a0
        logger.info("length: " + length + ", xlength: " + xlength + ", alength: " + alength);

        // Least-square error of each iteration -> the parameter snapshot / MAPE that
        // produced it. (Iterations with an identical error overwrite each other, which
        // is harmless: any of them is an equally good answer.)
        Map<Double, Double[]> map = new HashMap<Double, Double[]>();
        Map<Double, Double> mapeMap = new HashMap<Double, Double>();
        double[] lsa = new double[trainCount];

        double[] a = generateArray(alength, 1.0d);
        double minMape = 1.0d;
        int executed = 0; // iterations actually run; may be < trainCount on early stop
        for (int i = 0; i < trainCount; i++){
            executed = i + 1;
            linearFormula.setParameters(a); // each iteration: set params, evaluate, score
            double[] p = linearFormula.calculateResult(x);
            // Snapshot the parameters BEFORE updating them: lsa[i] scores these values.
            Double[] ao = new Double[alength];
            for (int j = 0; j < alength; j++){
                ao[j] = a[j];
            }
            if (needBias) {
                // Bias (a0) update: the gradient term carries no x factor.
                a[0] = a[0] - (learningRate / length) * sum(y, p, null);
            }
            for (int j = 0; j < xlength; j++){
                a[j + 1] = a[j + 1] - (learningRate / length) * sum(y, p, x[j]);
            }
            lsa[i] = calculateLeastSquare(y, p);
            map.put(lsa[i], ao);
            if (expectMape >= 0) {
                double mape = getMAPE(y, p);
                if (minMape > mape) {
                    minMape = mape;
                }
                mapeMap.put(lsa[i], mape);
                if (expectMape >= mape) { // expectation reached, stop training
                    logger.info("reach the expected result, end.");
                    break;
                }
            }
        }
        // Only search the iterations that actually ran. The original scanned the whole
        // lsa array, so an early break left trailing 0.0 slots, min became 0.0, and
        // map.get(min) returned null -> NullPointerException below.
        double min = getMinValueFrom1D(Arrays.copyOf(lsa, executed));
        logger.info("min: " + min + ", mape: " + mapeMap.get(min));
        Double[] finalA = map.get(min);
        double[] ret = new double[finalA.length];
        for (int idx = 0; idx < finalA.length; idx++){
            ret[idx] = finalA[idx] == null ? 0.0 : finalA[idx].doubleValue();
        }
        // Log the whole array: indexing ret[1] directly would throw when there are
        // no x variables (alength == 1).
        logger.info("ret: " + Arrays.toString(ret) + ", minMape: " + minMape + ", expectMape: " + expectMape);
        return ret;
    }

    /**
     * Utility: perturbs the input array by up to {@code offset} (relative), used to
     * generate noisy test data.
     *
     * @param da the clean values
     * @param offset maximum relative perturbation, e.g. 0.2 for +/-20%
     * @return a new array with each element randomly perturbed
     */
    public static double[] confusion(double[] da, double offset){
        double[] ret = new double[da.length];
        for (int i = 0; i < da.length; i++){
            double random = Math.random();
            // Randomly choose the perturbation direction.
            double flag = ((int)(random * 10)) % 2 == 0 ? 1 : -1;
            double tmp = da[i] * offset;
            if (random >= da[i] * offset){
                ret[i] = da[i] + flag * tmp;
            } else {
                ret[i] = da[i] + (da[i] * flag * random);
            }
        }
        return ret;
    }

    /**
     * Least-squares cost: sum of squared residuals divided by 2 * n
     * (the 1/2 factor is conventional so the gradient has no stray constant).
     */
    public static double calculateLeastSquare(double[] y, double[] p){
        double ret = 0;
        for (int i = 0; i < y.length; i++){
            double lack = y[i] - p[i];
            ret += (lack * lack);
        }
        ret = ret / (2 * y.length);
        return ret;
    }

    /**
     * Mean absolute percentage error between source and predict.
     * A zero source value uses 1 as the denominator to avoid dividing by zero.
     */
    public static double getMAPE(double[] source, double[] predict){
        Validate.isTrue(source.length > 0,
                "The array length of the parameters source and predict should be more than 0.");
        Validate.isTrue(source.length == predict.length,
                "The array length of the parameters source and predict should be same. source.length: " + source.length +
                        ", predict.length: " + predict.length);
        double ret = 0.0;
        for (int i = 0; i < source.length; i++){
            double base = source[i];
            if (base == 0){
                base = 1;
            }
            ret = ret + Math.abs((source[i] - predict[i]) / base);
        }
        ret = ret / source.length;
        return ret;
    }

    /**
     * Returns the relative error at the given percentile, i.e. the value below which
     * {@code topNumber} percent of the per-sample relative errors fall.
     *
     * @param topNumber percentile in (0, 100)
     */
    public static double getTopDifference(double[] source, double[] predict, int topNumber){
        Validate.isTrue(topNumber > 0 && topNumber < 100,
                "The parameter topNumber should be more than 0 and less then 100.");
        Validate.isTrue(source.length > 0,
                "The array length of the parameters source and predict should be more than 0.");
        Validate.isTrue(source.length == predict.length,
                "The array length of the parameters source and predict should be same.");
        double ret = 0.0;
        double[] difference = new double[source.length];
        for (int i = 0; i < source.length; i++){
            double base = source[i];
            if (base == 0){
                base = 1;
            }
            difference[i] = Math.abs((source[i] - predict[i]) / base);
        }
        Arrays.sort(difference); // order by ascent
        double percent = ((double) topNumber) / 100.0d;
        int idx = (int)(percent * (source.length - 1));
        ret = difference[idx];
        return ret;
    }

    /** Returns the minimum value of a non-empty array. */
    public static double getMinValueFrom1D(double[] xa){
        double ret = xa[0];
        for (double d : xa){
            if (d < ret){
                ret = d;
            }
        }
        return ret;
    }
}
下面就可以开始测试了,例子是 高斯分布公式(GaussianDistribution) ,如下是高斯分布的实现类
高斯分布的实现类
/**
 * Skeleton implementation of {@link LinearFormula}: stores the parameters and maps
 * the column-oriented input (one array per variable) to per-sample rows, delegating
 * the actual math and the validation to subclasses.
 */
public abstract class BaseLinearFormula implements LinearFormula{
    private double[] parameters;

    public double[] getParameters() {
        return parameters;
    }

    @Override
    public void setParameters(double[] parameters) {
        // Let the subclass reject ill-formed parameters before storing them.
        validateParameters(parameters);
        this.parameters = parameters;
    }

    @Override
    public double[] calculateResult(double[]... x) {
        validateInputData(x);
        final int sampleCount = x[0].length;
        final int variableCount = x.length;
        double[] results = new double[sampleCount];
        for (int s = 0; s < sampleCount; s++){
            // Gather the s-th value of every variable into one row for the subclass.
            double[] row = new double[variableCount];
            for (int v = 0; v < variableCount; v++){
                row[v] = x[v][s];
            }
            results[s] = retrieveValue(parameters, row);
        }
        return results;
    }

    /** Computes y for one sample given the parameters and that sample's variable values. */
    protected abstract double retrieveValue(double[] parameters, double[] data);

    /** Rejects input that does not fit the concrete formula (e.g. wrong variable count). */
    protected abstract void validateInputData(double[]... x) throws IllegalArgumentException;

    /** Rejects parameter arrays that do not fit the concrete formula. */
    protected abstract void validateParameters(double[] parameters) throws IllegalArgumentException;
}
/**
* 高斯分布的实现,用于测试用。
*/
public class GaussianDistribution extends BaseLinearFormula{
public static double[] generateGaussianDistributioSampleData(int n){
int times = 12;
double[] x = new double[n * times];
double tmp = 0d;
for (int i = 0; i < n * times; i++){
tmp = tmp + (1.0 / times);
x[i] = tmp - n / 2;
}
return x;
}
@Override
protected double retrieveValue(double[] parameters, double[] data) {
double ret = 0;
double exponent = -1 * (Math.pow((data[0] - parameters[0]), 2) / (2 * Math.pow(parameters[1], 2)));
ret = (1 / (parameters[1] * Math.sqrt(2 * Math.PI))) * Math.exp(exponent);
return ret;
}
@Override
protected void validateInputData(double[]... x) throws IllegalArgumentException {
int variableNumber = x.length;
Validate.isTrue(variableNumber == 1, "Gaussian Distribution should have 1 variable");
}
@Override
protected void validateParameters(double[] parameters) throws IllegalArgumentException {
Validate.isTrue(parameters.length == 2, "Gaussian Distribution should have 2 parameters");
}
}
然后就是测试了
// The Gaussian distribution is not an ideal example, but it demonstrates the usage.
GaussianDistribution gaussianDistribution = new GaussianDistribution();
GradientDescent gradientDescent = new GradientDescent(0.1, gaussianDistribution, false);
// Generate the test inputs and the "true" parameters.
double[] xtest = GaussianDistribution.generateGaussianDistributioSampleData(6);
double[] a = {0, 0.6};
gaussianDistribution.setParameters(a);
// Compute the exact curve from the formula.
double[] sy = gaussianDistribution.calculateResult(xtest);
// Perturb the exact values; the noisy data is what training runs against,
// and the deviation metrics below show how noisy it is.
double[] cy = GradientDescent.confusion(sy, 0.2);
System.out.println("confusion MAPE: " + GradientDescent.getMAPE(cy, sy) +
        ", confusion top 90: " + GradientDescent.getTopDifference(cy, sy, 90) +
        ", confusion top 80: " + GradientDescent.getTopDifference(cy, sy, 80));
System.out.println("a: " + a[0] + ", " + a[1]);
// Train against the noisy data and the original inputs; the result is the fitted parameters.
double[] atest = gradientDescent.train(cy, 0.0, 6000, xtest);
System.out.println("atest: " + atest[0] + ", " + atest[1]);
gaussianDistribution.setParameters(atest); // apply the fitted parameters
double[] p = gaussianDistribution.calculateResult(xtest); // recompute over the inputs
// FIX: the original called println with three arguments, which does not compile;
// format the array explicitly instead.
System.out.println("final predict GaussianDistribution: " + Arrays.toString(p));
// Print the deviation metrics of the fit.
System.out.println("final MAPE: " + GradientDescent.getMAPE(cy, p) +
        ", final top 90: " + GradientDescent.getTopDifference(cy, p, 90) +
        ", final top 80: " + GradientDescent.getTopDifference(cy, p, 80));
--------效果不好,供参考-----------------------------------------------------------------------------------------------------------------
输出结果
confusion MAPE: 0.19411157584986596, confusion top 90: 0.25000000000000006, confusion top 80: 0.25
a: 0.0, 0.6
偏差比较大
final predict GaussianDistribution: : 0.0058, 0.0069, 0.0081, 0.0094, 0.011, 0.0128, 0.0148, 0.017, 0.0196, 0.0224, 0.0256, 0.0291, 0.0329, 0.0372, 0.0418, 0.0469, 0.0524, 0.0583, 0.0647, 0.0715, 0.0788, 0.0865, 0.0946, 0.1031, 0.112, 0.1212, 0.1307, 0.1405, 0.1505, 0.1606, 0.1708, 0.181, 0.1911, 0.2011, 0.2109, 0.2204, 0.2295, 0.2381, 0.2462, 0.2537, 0.2605, 0.2665, 0.2717, 0.276, 0.2794, 0.2819, 0.2834, 0.2839, 0.2834, 0.2819, 0.2794, 0.276, 0.2717, 0.2665, 0.2605, 0.2537, 0.2462, 0.2381, 0.2295, 0.2204, 0.2109, 0.2011, 0.1911, 0.181, 0.1708, 0.1606, 0.1505, 0.1405, 0.1307, 0.1212, 0.112, 0.1031,
final MAPE: 1429.2300451464919, final top 90: 988.6810056795385, final top 80: 170.0713463977803