《高等数学》中的最小二乘法
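书中公式：对数据点 $(x_i, y_i)$（$i = 1, \dots, n$）拟合直线 $y = ax + b$，要求偏差平方和 $M(a, b) = \sum_{i=1}^{n}\left[y_i - (a x_i + b)\right]^2$ 最小。令 $\frac{\partial M}{\partial a} = 0$、$\frac{\partial M}{\partial b} = 0$，得到方程组

$$
\begin{cases}
a\sum_{i=1}^{n} x_i^2 + b\sum_{i=1}^{n} x_i = \sum_{i=1}^{n} x_i y_i \\[4pt]
a\sum_{i=1}^{n} x_i + n\,b = \sum_{i=1}^{n} y_i
\end{cases}
$$

代入书中的 8 组数据（即下文代码中的训练数据）：$\sum x_i = 28$，$\sum x_i^2 = 140$，$\sum y_i = 208.5$，$\sum x_i y_i = 717$。解得 $a \approx -0.3036$，$b = 27.125$。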
以上是书中用最小二乘法求解的过程。接下来，我们使用 DL4J（Deeplearning4j）训练一个单神经元的线性回归模型，来求解同一问题：
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class SingleRegression2 {
private static final int seed = 12345;
private static final int nEpochs = 1200;//训练次数
private static final double learningRate = 0.01;//学习率
public static void main(String[] args) {
int numInput = 1;
int numOutputs = 1;
.seed(seed)
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(learningRate)).list().layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
.activation(Activation.IDENTITY)
.nIn(numInput)
.nOut(numOutputs).build())
.pretrain(false)
.backprop(true).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
System.out.println(net.summary());
DataSetIterator iterator = getTrainingData();
for (int i = 0; i < nEpochs; i++) {
iterator.reset();
net.fit(iterator);
Map<String, INDArray> params = net.paramTable();
params.forEach((key, value) -> System.out.println("key:" + key + ", value = " + value));
}
int lenth = 20;
double[] data = new double[lenth];
for (int i = 0; i < lenth; i++) {
data[i] = i;
}
// 测试
final INDArray input = Nd4j.create(data, new int[] { data.length, 1 });
INDArray out = net.output(input, false);
System.out.println(out);
}
private static DataSetIterator getTrainingData() {
//输入x与y 作为训练的数据
double[] input = new double[] { 0, 1, 2, 3, 4, 5, 6, 7 };
double[] output = new double[] { 27.0d, 26.8d, 26.5d, 26.3d, 26.1d, 25.7d, 25.3d, 24.8d };
INDArray inputNDArray = Nd4j.create(input, new int[] { input.length, 1 });
INDArray outPut = Nd4j.create(output, new int[] { output.length, 1 });
DataSet dataSet = new DataSet(inputNDArray, outPut);
List<DataSet> listDs = dataSet.asList();
return new ListDataSetIterator(listDs, listDs.size());
}
}
执行程序
中间的训练日志从略，最后的执行结果如下：
学到的斜率 $a \approx -0.3027$、截距 $b \approx 27.1205$，与书中的答案（$a \approx -0.3036$，$b = 27.125$）已经很接近了。
将训练次数修改为 2000 次后再次执行，得到 $a \approx -0.3035$、$b \approx 27.1247$，与书中的答案更为接近。