基于DL4J的自动编码器
一、简介
为什么要使用自动编码器? 在实践中,自动编码器通常应用于数据的降噪和降维。 这对于表示学习非常有用,而对于数据压缩则不太有用。
在深度学习中,自动编码器是“尝试”以重建其输入的神经网络。 它可以用作特征提取的一种形式,并且可以将自动编码器堆叠起来以创建“深度”网络。 由自动编码器生成的功能可以输入到其他算法中,以进行分类,聚类和异常检测。
当原始输入数据具有高维且无法轻松绘制时,自动编码器还可用于数据可视化。 通过降维,有时可以将输出压缩到2D或3D空间中,以进行更好的数据探索。
在实际应用当中,异常检测能够用于:网络入侵,欺诈检测,系统监视,传感器网络事件检测(IoT)和异常轨迹感测。
二、自编码器的工作流程
自动编码器包括:
1、编码功能(“编码器”)
2、解码功能(“解码器”)
3、距离函数(“损失函数”)
首先,输入被馈入自动编码器并转换为压缩表示。然后,解码器学习如何从压缩的表示中重建原始输入,在无监督的训练过程中,损失函数有助于纠正解码器产生的错误。 此过程是自动的(因此称为“自动”编码器); 即不需要人工干预。
学习到现在,我们应该已经知道如何使用MultiLayerNetwork和ComputationGraph创建不同的网络配置了,现在,我们将构造一个“堆叠”自动编码器,该编码器对MNIST数字执行异常检测而无需预先训练。而目的是识别异常数字,即不寻常和不典型的数字。从给定数据集的规范中“脱颖而出”的内容,事件或观察结果的识别被广泛称为异常检测。异常检测不需要标注的数据集,并且可以在无监督学习的情况下进行,这很有帮助,因为世界上大多数数据都没有标注。
通常,异常检测使用重构误差来衡量解码器的性能。正常的数据应具有较低的重构误差,而异常值应具有较高的重构误差。
三、基于DL4J的自编码器实现
3.1、导入需要的包
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.util.*;
import java.util.List;
3.2、堆叠式自动编码器
以下自动编码器使用两个堆叠的密集层进行编码。 MNIST数字转换为长度为784的平面一维数组(MNIST图像为28x28像素,当我们端对端放置它们时等于784)。在网络中,数据的大小变化情况如下:
784→250→10→250→784
代码如下:
//搭建模型
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345)
.weightInit(WeightInit.XAVIER)
.updater(new AdaGrad(0.05))
.activation(Activation.RELU)
.l2(0.0001)
.list()
.layer(new DenseLayer.Builder().nIn(784).nOut(250)
.build())
.layer(new DenseLayer.Builder().nIn(250).nOut(10)
.build())
.layer(new DenseLayer.Builder().nIn(10).nOut(250)
.build())
.layer(new OutputLayer.Builder().nIn(250).nOut(784)
.activation(Activation.LEAKYRELU)
.lossFunction(LossFunctions.LossFunction.MSE)
.build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
//监听器
net.setListeners(Collections.singletonList(new ScoreIterationListener(10)));
3.3、使用MNIST迭代器
像Deeplearning4j的大多数内置迭代器一样,MNIST迭代器扩展了DataSetIterator类。 该API允许简单地实例化数据集并在后台自动下载数据。
代码如下:
//加载数据,并且进行训练集与测试集的划分:40000训练数据,10000测试数据
DataSetIterator iter = new MnistDataSetIterator(100,50000,false);
List<INDArray> featuresTrain = new ArrayList<>();
List<INDArray> featuresTest = new ArrayList<>();
List<INDArray> labelsTest = new ArrayList<>();
Random r = new Random(12345);
while(iter.hasNext()){
DataSet ds = iter.next();
SplitTestAndTrain split = ds.splitTestAndTrain(80, r); //按照8:2的比例划分数据集 (miniBatch = 100)
featuresTrain.add(split.getTrain().getFeatures());
DataSet dsTest = split.getTest();
featuresTest.add(dsTest.getFeatures());
INDArray indexes = Nd4j.argMax(dsTest.getLabels(),1); //进行独热编码转换: 表示 -> 索引
labelsTest.add(indexes);
}
3.4、无监督训练
现在,我们已经设置了网络配置并与我们的MNIST测试/训练迭代器一起实例化了,训练只需要几行代码。
之前,我们使用setListeners()方法将ScoreIterationListener附加到模型。根据用于运行此代码电脑的浏览器,可以打开调试器/检查器以查看侦听器输出。 由于Deeplearning4j的内部使用SL4J进行日志记录,因此此输出重定向到控制台,并且Zeppelin重定向了该输出。 这有助于减少电脑的混乱情况。
代码如下:
//训练模型
int nEpochs = 3;
for( int epoch=0; epoch<nEpochs; epoch++ ){
for(INDArray data : featuresTrain){
net.fit(data,data);
}
System.out.println("Epoch " + epoch + " complete");
}
3.5、评估模型
现在,我们已经对自动编码器进行了训练,那么,我们将根据测试数据来评估模型。每个示例将被单独打分,并且将构成一个映射,该映射将每个数字与(得分,示例)对列表相关联。
最后,我们将计算N个最佳分数和N个最差分数。
代码如下:
//根据测试数据评估模型
//分别对测试集中的每个样本评分
//组成一个映射,将每个数字与(得分,样本)对列表相关联
//然后找到每位数中N个最佳分数和N个最差分数
Map<Integer,List<Pair<Double,INDArray>>> listsByDigit = new HashMap<>();
for( int i=0; i<10; i++ ) listsByDigit.put(i,new ArrayList<>());
for( int i=0; i<featuresTest.size(); i++ ){
INDArray testData = featuresTest.get(i);
INDArray labels = labelsTest.get(i);
int nRows = testData.rows();
for( int j=0; j<nRows; j++){
INDArray example = testData.getRow(j, true);
int digit = (int)labels.getDouble(j);
double score = net.score(new DataSet(example,example));
// 将(得分,样本)对添加到适当的列表
List digitAllPairs = listsByDigit.get(digit);
digitAllPairs.add(new ImmutablePair<>(score, example));
}
}
//Sort each list in the map by score
Comparator<Pair<Double, INDArray>> c = new Comparator<Pair<Double, INDArray>>() {
@Override
public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) {
return Double.compare(o1.getLeft(),o2.getLeft());
}
};
for(List<Pair<Double, INDArray>> digitAllPairs : listsByDigit.values()){
Collections.sort(digitAllPairs, c);
}
//排序后,为每个数字选择N个最佳分数和N个最差分数(根据重构误差),其中N = 5
List<INDArray> best = new ArrayList<>(50);
List<INDArray> worst = new ArrayList<>(50);
for( int i=0; i<10; i++ ){
List<Pair<Double,INDArray>> list = listsByDigit.get(i);
for( int j=0; j<5; j++ ){
best.add(list.get(j).getRight());
worst.add(list.get(list.size()-j-1).getRight());
}
}
3.6、结果可视化
//默认可视化
if (visualize) {
//可视化最好和最差的数字
MNISTVisualizer bestVisualizer = new MNISTVisualizer(2.0, best, "Best (Low Rec. Error)");
bestVisualizer.visualize();
MNISTVisualizer worstVisualizer = new MNISTVisualizer(2.0, worst, "Worst (High Rec. Error)");
worstVisualizer.visualize();
}
//可视化方法
public static class MNISTVisualizer {
private double imageScale;
private List<INDArray> digits; //数字(作为行向量),每个INDArray一个
private String title;
private int gridWidth;
public MNISTVisualizer(double imageScale, List<INDArray> digits, String title ) {
this(imageScale, digits, title, 5);
}
public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth ) {
this.imageScale = imageScale;
this.digits = digits;
this.title = title;
this.gridWidth = gridWidth;
}
public void visualize(){
JFrame frame = new JFrame();
frame.setTitle(title);
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
JPanel panel = new JPanel();
panel.setLayout(new GridLayout(0,gridWidth));
List<JLabel> list = getComponents();
for(JLabel image : list){
panel.add(image);
}
frame.add(panel);
frame.setVisible(true);
frame.pack();
}
private List<JLabel> getComponents(){
List<JLabel> images = new ArrayList<>();
for( INDArray arr : digits ){
BufferedImage bi = new BufferedImage(28,28,BufferedImage.TYPE_BYTE_GRAY);
for( int i=0; i<784; i++ ){
bi.getRaster().setSample(i % 28, i / 28, 0, (int)(255*arr.getDouble(i)));
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((int)(imageScale*28),(int)(imageScale*28),Image.SCALE_REPLICATE);
ImageIcon scaled = new ImageIcon(imageScaled);
images.add(new JLabel(scaled));
}
return images;
}
}
最终结果如下图所示:
最差的手写数字:
最好的手写数字:
完整代码如下:
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.util.*;
import java.util.List;
public class xx {
public static boolean visualize = true;
public static void main(String[] args) throws Exception {
//搭建模型. 784 输入/输出 (MNIST 图片大小为 28x28).
//784 -> 250 -> 10 -> 250 -> 784
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345)
.weightInit(WeightInit.XAVIER)
.updater(new AdaGrad(0.05))
.activation(Activation.RELU)
.l2(0.0001)
.list()
.layer(new DenseLayer.Builder().nIn(784).nOut(250)
.build())
.layer(new DenseLayer.Builder().nIn(250).nOut(10)
.build())
.layer(new DenseLayer.Builder().nIn(10).nOut(250)
.build())
.layer(new OutputLayer.Builder().nIn(250).nOut(784)
.activation(Activation.LEAKYRELU)
.lossFunction(LossFunctions.LossFunction.MSE)
.build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(Collections.singletonList(new ScoreIterationListener(10)));
//加载数据,并且进行训练集与测试集的划分:40000训练数据,10000测试数据
DataSetIterator iter = new MnistDataSetIterator(100,50000,false);
List<INDArray> featuresTrain = new ArrayList<>();
List<INDArray> featuresTest = new ArrayList<>();
List<INDArray> labelsTest = new ArrayList<>();
Random r = new Random(12345);
while(iter.hasNext()){
DataSet ds = iter.next();
SplitTestAndTrain split = ds.splitTestAndTrain(80, r); //按照8:2的比例进行划分(from miniBatch = 100)
featuresTrain.add(split.getTrain().getFeatures());
DataSet dsTest = split.getTest();
featuresTest.add(dsTest.getFeatures());
INDArray indexes = Nd4j.argMax(dsTest.getLabels(),1); //通过独热编码将表示转换为索引
labelsTest.add(indexes);
}
//训练模型
int nEpochs = 3;
for( int epoch=0; epoch<nEpochs; epoch++ ){
for(INDArray data : featuresTrain){
net.fit(data,data);
}
System.out.println("Epoch " + epoch + " complete");
}
//根据测试数据评估模型
//分别对测试集中的每个样本评分
//组成一个映射,将每个数字与(得分,样本)对列表相关联
//然后找到每位数中N个最佳分数和N个最差分数
Map<Integer,List<Pair<Double,INDArray>>> listsByDigit = new HashMap<>();
for( int i=0; i<10; i++ ) listsByDigit.put(i,new ArrayList<>());
for( int i=0; i<featuresTest.size(); i++ ){
INDArray testData = featuresTest.get(i);
INDArray labels = labelsTest.get(i);
int nRows = testData.rows();
for( int j=0; j<nRows; j++){
INDArray example = testData.getRow(j, true);
int digit = (int)labels.getDouble(j);
double score = net.score(new DataSet(example,example));
// 将(得分,样本)对添加到适当的列表
List digitAllPairs = listsByDigit.get(digit);
digitAllPairs.add(new ImmutablePair<>(score, example));
}
}
//按分数映射对每个列表进行排序
Comparator<Pair<Double, INDArray>> c = new Comparator<Pair<Double, INDArray>>() {
@Override
public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) {
return Double.compare(o1.getLeft(),o2.getLeft());
}
};
for(List<Pair<Double, INDArray>> digitAllPairs : listsByDigit.values()){
Collections.sort(digitAllPairs, c);
}
排序后,为每个数字选择N个最佳分数和N个最差分数(根据重构误差),其中N = 5
List<INDArray> best = new ArrayList<>(50);
List<INDArray> worst = new ArrayList<>(50);
for( int i=0; i<10; i++ ){
List<Pair<Double,INDArray>> list = listsByDigit.get(i);
for( int j=0; j<5; j++ ){
best.add(list.get(j).getRight());
worst.add(list.get(list.size()-j-1).getRight());
}
}
//默认进行可视化
if (visualize) {
//可视化最好与最差的数字
MNISTVisualizer bestVisualizer = new MNISTVisualizer(2.0, best, "Best (Low Rec. Error)");
bestVisualizer.visualize();
MNISTVisualizer worstVisualizer = new MNISTVisualizer(2.0, worst, "Worst (High Rec. Error)");
worstVisualizer.visualize();
}
}
public static class MNISTVisualizer {
private double imageScale;
private List<INDArray> digits; //Digits (as row vectors), one per INDArray
private String title;
private int gridWidth;
public MNISTVisualizer(double imageScale, List<INDArray> digits, String title ) {
this(imageScale, digits, title, 5);
}
//设定可视化的图片大小、数字、标题与显示图片的网格大小
public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth ) {
this.imageScale = imageScale;
this.digits = digits;
this.title = title;
this.gridWidth = gridWidth;
}
public void visualize(){
JFrame frame = new JFrame();
frame.setTitle(title);
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
JPanel panel = new JPanel();
panel.setLayout(new GridLayout(0,gridWidth));
List<JLabel> list = getComponents();
for(JLabel image : list){
panel.add(image);
}
frame.add(panel);
frame.setVisible(true);
frame.pack();
}
private List<JLabel> getComponents(){
List<JLabel> images = new ArrayList<>();
for( INDArray arr : digits ){
BufferedImage bi = new BufferedImage(28,28,BufferedImage.TYPE_BYTE_GRAY);
for( int i=0; i<784; i++ ){
bi.getRaster().setSample(i % 28, i / 28, 0, (int)(255*arr.getDouble(i)));
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((int)(imageScale*28),(int)(imageScale*28),Image.SCALE_REPLICATE);
ImageIcon scaled = new ImageIcon(imageScaled);
images.add(new JLabel(scaled));
}
return images;
}
}
}