// DL4J example: detecting anomalous handwritten digits with an autoencoder network
// trained WITHOUT layerwise pretraining. Typical digits reconstruct with low error;
// anomalous digits reconstruct with high error.
public class MNISTAnomalyExample {
    public static void main(String[] args) throws Exception {
        // Set up network. 784 in/out (as MNIST images are 28x28 = 784 pixels).
        // Autoencoder layer sizes: 784 -> 250 -> 10 -> 250 -> 784.
        // The first half (down to the 10-unit bottleneck) acts as the encoder.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345) // fixed RNG seed for reproducibility
                .iterations(1)
                .weightInit(WeightInit.XAVIER)
                // AdaGrad updater: per-parameter adaptive learning rate that decays in
                // proportion to the accumulated squared gradients (large gradients ->
                // faster decay, small gradients -> slower decay).
                .updater(Updater.ADAGRAD)
                .activation("relu")
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(0.05)
                .regularization(true).l2(0.0001) // L2 weight decay
                .list()
                .layer(0, new DenseLayer.Builder().nIn(784).nOut(250)
                        .build())
                .layer(1, new DenseLayer.Builder().nIn(250).nOut(10)
                        .build())
                .layer(2, new DenseLayer.Builder().nIn(10).nOut(250)
                        .build())
                // Output layer reconstructs the input; MSE is the reconstruction loss.
                .layer(3, new OutputLayer.Builder().nIn(250).nOut(784)
                        .lossFunction(LossFunctions.LossFunction.MSE)
                        .build())
                .pretrain(false).backprop(true) // no layerwise pretraining; plain backprop
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        // Log the score after every iteration.
        net.setListeners(Collections.singletonList((IterationListener) new ScoreIterationListener(1)));

        // Load data and split into training and testing sets. 40000 train, 10000 test.
        // Iterator args: batch size 100, 50000 total examples, binarize = false.
        DataSetIterator iter = new MnistDataSetIterator(100,50000,false);
        List<INDArray> featuresTrain = new ArrayList<>();
        List<INDArray> featuresTest = new ArrayList<>();
        // NOTE(review): this chunk is truncated here — the remainder of main()
        // (and the closing braces) lies outside the visible source.