Deeplearning4j线性回归

《高等数学》中的最小二乘法

 

 

书中公式 

 

以上是书上的最小二乘法,接下来我们使用dl4j训练求解

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;


public class SingleRegression2 {

	private static final int seed = 12345;
	private static final int nEpochs = 1200;//训练次数
	private static final double learningRate = 0.01;//学习率

	public static void main(String[] args) {
		int numInput = 1;
		int numOutputs = 1;
				.seed(seed)
				.weightInit(WeightInit.XAVIER)
				.updater(new Sgd(learningRate)).list().layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
						.activation(Activation.IDENTITY)
						.nIn(numInput)
						.nOut(numOutputs).build())
				.pretrain(false)
				.backprop(true).build();
		MultiLayerNetwork net = new MultiLayerNetwork(conf);
		net.init();
		System.out.println(net.summary());
		DataSetIterator iterator = getTrainingData();
		for (int i = 0; i < nEpochs; i++) {
			iterator.reset();

			net.fit(iterator);

			Map<String, INDArray> params = net.paramTable();
			params.forEach((key, value) -> System.out.println("key:" + key + ", value = " + value));

		}

		int lenth = 20;
		double[] data = new double[lenth];
		for (int i = 0; i < lenth; i++) {
			data[i] = i;
		}

		// 测试
		final INDArray input = Nd4j.create(data, new int[] { data.length, 1 });
		INDArray out = net.output(input, false);
		System.out.println(out);
	}

	private static DataSetIterator getTrainingData() {
        //输入x与y 作为训练的数据
		double[] input = new double[] { 0, 1, 2, 3, 4, 5, 6, 7 };
		double[] output = new double[] { 27.0d, 26.8d, 26.5d, 26.3d, 26.1d, 25.7d, 25.3d, 24.8d };

		INDArray inputNDArray = Nd4j.create(input, new int[] { input.length, 1 });
		INDArray outPut = Nd4j.create(output, new int[] { output.length, 1 });

		DataSet dataSet = new DataSet(inputNDArray, outPut);
		List<DataSet> listDs = dataSet.asList();

		return new ListDataSetIterator(listDs, listDs.size());
	}
}

执行程序

 中间省略训练的日志,最后执行结果

-0.3027  和 27.1205与书中的答案已经很接近了

 

我们修改训练次数位2000次,执行结果如下图,-0.3035 与 27.1247 更为接近

 

 

 

 

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

邓霖涛

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值