Adam 也是做线性回归时比较常用的方法,很多时候比梯度下降好用。它的实现和梯度下降很类似,只是增加了超参数 beta1 和 beta2 来控制各移动均值的衰减率,具体介绍可详见:https://www.cnblogs.com/yifdu25/p/8183587.html 。Adam 带来的效果是:训练时可以极度接近极限值,而不会再逐渐偏离;而梯度下降则有可能发生偏离越来越大的情况。
如下的2个图就是对算法的描述,摘自上面的url。
有了这个图,我们就可以写代码了,它的实现与梯度下降类似,可以在梯度下降的代码基础上修改。
public class Adam {

    private static final Logger logger = LoggerFactory.getLogger(Adam.class);

    /** Step size (alpha in the Adam paper). */
    private final double learningRate;
    /** The linear model whose parameters are being fitted. */
    private final LinearFormula linearFormula;
    /** Whether the bias term a0 is trained; when false, a0 stays at its initial 0.0. */
    private final boolean needBias;
    /** Exponential decay rate of the first-moment estimate (b1 in the Adam paper). */
    private final double b1 = 0.9d;
    /** Exponential decay rate of the second-moment estimate (b2 in the Adam paper). */
    private final double b2 = 0.999d;
    /** Small constant preventing division by zero (epsilon in the Adam paper). */
    private final double e = 1e-8;

    public Adam(double learningRate, LinearFormula linearFormula, boolean needBias) {
        this.learningRate = learningRate;
        this.linearFormula = linearFormula;
        this.needBias = needBias;
    }

    /**
     * Mean of (p[i] - y[i]), optionally weighted by x[i]: the partial derivative of the
     * half-MSE loss with respect to one parameter. Pass {@code x == null} (or empty) for
     * the bias parameter a0, whose gradient has no variable factor.
     *
     * @param y observed target values
     * @param p current model predictions (same length as y)
     * @param x variable values for the parameter being differentiated, or null for a0
     * @return the averaged gradient component
     */
    private double sum(double[] y, double[] p, double[] x) {
        boolean weighted = ArrayUtils.isNotEmpty(x);
        double total = 0;
        for (int i = 0; i < y.length; i++) {
            double residual = p[i] - y[i];
            if (weighted) {
                residual *= x[i];
            }
            total += residual;
        }
        return total / y.length;
    }

    /** Returns a new array of the given length with every slot set to {@code pointValue}. */
    private double[] generateArray(int length, double pointValue) {
        double[] ret = new double[length];
        Arrays.fill(ret, pointValue);
        return ret;
    }

    /** Boxes a primitive double array into a Double array (snapshot for the result maps). */
    private static Double[] boxArray(double[] source) {
        Double[] boxed = new Double[source.length];
        for (int i = 0; i < source.length; i++) {
            boxed[i] = source[i];
        }
        return boxed;
    }

    /**
     * Fits the linear formula's parameters with the Adam update rule and returns the
     * parameter vector that produced the smallest least-square loss seen during training.
     *
     * @param y          observed target values; must not be empty
     * @param expectMape target MAPE; when >= 0 training stops early once reached,
     *                   when negative MAPE is not evaluated at all
     * @param trainCount maximum number of iterations; must be positive
     * @param x          one array per input variable, each the same length as y
     * @return the best parameter vector found: [a0, a1, ..., a_xlength]
     * @throws IllegalArgumentException if y is empty or trainCount is not positive
     */
    public double[] train(double[] y, double expectMape, int trainCount, double[]... x) {
        if (y == null || y.length == 0) {
            // Previously only logged; an empty target would later yield NaN losses.
            throw new IllegalArgumentException("The training target array y must not be empty.");
        }
        if (trainCount <= 0) {
            throw new IllegalArgumentException("trainCount must be positive, got: " + trainCount);
        }
        int length = y.length;
        int xlength = x.length;
        int alength = xlength + 1; // one parameter per variable plus the bias term a0
        logger.info("length: " + length + ", xlength: " + xlength + ", alength: " + alength);

        // Results of every iteration, keyed by that iteration's least-square loss.
        Map<Double, Double[]> parameterByLoss = new HashMap<Double, Double[]>();
        Map<Double, Double> mapeByLoss = new HashMap<Double, Double>();
        Map<Double, Double[]> forecastByLoss = new HashMap<Double, Double[]>();
        double[] lsa = new double[trainCount];

        double[] a = generateArray(alength, 0.5d);
        a[0] = 0.0d;
        double[] m = generateArray(alength, 0.0d); // first-moment (mean) estimates
        double[] v = generateArray(alength, 0.0d); // second-moment (uncentered variance) estimates
        double minMape = Double.MAX_VALUE;
        int completed = 0; // iterations actually executed; the loop may break early
        for (int i = 0; i < trainCount; i++) {
            completed = i + 1;
            linearFormula.setParameters(a);
            double[] p = linearFormula.calculateResult(x);
            // Snapshot the parameters BEFORE this iteration's update, matching lsa[i],
            // which is computed from predictions made with these parameters.
            Double[] snapshot = boxArray(a);
            if (i % 5000 == 0 || i < 5) {
                MatrixHelper.outputMatrix1(i + " a", a);
            }
            // Bias-correction factors b1^t and b2^t for step t = i + 1 (as in TensorFlow).
            double stepb1 = Math.pow(b1, i + 1);
            double stepb2 = Math.pow(b2, i + 1);
            for (int j = 0; j < alength; j++) {
                int xIdx = j - 1;
                if (!needBias && xIdx < 0) {
                    continue; // bias disabled: leave a0 untouched
                }
                // j == 0 is the bias a0 (no variable factor); otherwise weight by x[j-1].
                double g = (j == 0) ? sum(y, p, null) : sum(y, p, x[xIdx]);
                m[j] = b1 * m[j] + (1 - b1) * g;
                v[j] = b2 * v[j] + (1 - b2) * g * g;
                double mHat = m[j] / (1 - stepb1); // bias-corrected first moment
                double vHat = v[j] / (1 - stepb2); // bias-corrected second moment
                a[j] = a[j] - learningRate * mHat / (Math.sqrt(vHat) + e);
            }
            lsa[i] = calculateLeastSquare(y, p);
            parameterByLoss.put(lsa[i], snapshot);
            forecastByLoss.put(lsa[i], boxArray(p));
            if (expectMape >= 0) {
                double mape = getMAPE(y, p);
                if (mape < minMape) {
                    minMape = mape;
                }
                mapeByLoss.put(lsa[i], mape);
                if (expectMape >= mape) { // reached the expected accuracy, stop training
                    logger.info("reach the expected result, end.");
                    break;
                }
            }
        }
        // Only consider losses from iterations that actually ran: after an early break the
        // trailing 0.0 entries of lsa are not map keys, and picking one of them as the
        // minimum used to cause a NullPointerException below.
        double min = getMinValueFrom1D(Arrays.copyOf(lsa, completed));
        logger.info("min: " + min + ", mape: " + mapeByLoss.get(min));
        Double[] finalA = parameterByLoss.get(min);
        double[] ret = new double[finalA.length];
        for (int idx = 0; idx < finalA.length; idx++) {
            ret[idx] = (finalA[idx] != null) ? finalA[idx].doubleValue() : 0.0;
        }
        // Guard ret[1]: with no x variables there is only one parameter (used to throw AIOOBE).
        logger.info("ret: " + ret[0]
                + (ret.length > 1 ? ", ret1: " + ret[1] : "")
                + ", minMape: " + minMape + ", expectMape: " + expectMape);
        return ret;
    }

    /**
     * Mean absolute percentage error of {@code predict} against {@code source}.
     * A zero source value uses 1 as the denominator to avoid division by zero.
     *
     * @throws IllegalArgumentException if the arrays are empty or differ in length
     */
    public static double getMAPE(double[] source, double[] predict) {
        Validate.isTrue(source.length > 0,
                "The array length of the parameters source and predict should be more than 0.");
        Validate.isTrue(source.length == predict.length,
                "The array length of the parameters source and predict should be same. source.length: " + source.length +
                        ", predict.length: " + predict.length);
        double ret = 0.0;
        for (int i = 0; i < source.length; i++) {
            double base = source[i];
            if (base == 0) {
                base = 1; // avoid division by zero for zero targets
            }
            ret += Math.abs((source[i] - predict[i]) / base);
        }
        return ret / source.length;
    }

    /**
     * Returns the absolute percentage error at the given percentile, i.e. the value
     * below which {@code topNumber} percent of the per-point errors fall.
     *
     * @param topNumber percentile in (0, 100)
     * @throws IllegalArgumentException if topNumber is out of range, the arrays are
     *                                  empty, or they differ in length
     */
    public static double getTopDifference(double[] source, double[] predict, int topNumber) {
        Validate.isTrue(topNumber > 0 && topNumber < 100,
                "The parameter topNumber should be more than 0 and less then 100.");
        Validate.isTrue(source.length > 0,
                "The array length of the parameters source and predict should be more than 0.");
        Validate.isTrue(source.length == predict.length,
                "The array length of the parameters source and predict should be same.");
        double[] difference = new double[source.length];
        for (int i = 0; i < source.length; i++) {
            double base = source[i];
            if (base == 0) {
                base = 1; // avoid division by zero for zero targets
            }
            difference[i] = Math.abs((source[i] - predict[i]) / base);
        }
        Arrays.sort(difference); // ascending order
        double percent = ((double) topNumber) / 100.0d;
        int idx = (int) (percent * (source.length - 1));
        return difference[idx];
    }

    /**
     * Minimum value of a non-empty array.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code xa} is empty
     */
    public static double getMinValueFrom1D(double[] xa) {
        double ret = xa[0];
        for (double d : xa) {
            if (d < ret) {
                ret = d;
            }
        }
        return ret;
    }

    /**
     * Half mean squared error: sum((y[i] - p[i])^2) / (2 * n). The 1/2 factor makes the
     * gradient of this loss exactly the averaged residual computed by {@code sum}.
     */
    public static double calculateLeastSquare(double[] y, double[] p) {
        double ret = 0;
        for (int i = 0; i < y.length; i++) {
            double lack = y[i] - p[i];
            ret += lack * lack;
        }
        return ret / (2 * y.length);
    }
}
这个算法可以复用梯度下降的测试方法来验证。可参考《Java程序员学算法(4) - 梯度下降(Gradient Descent)》。