前提:需要使用64位的JDK,32位的JDK无法运行deeplearning4j。
1、在Maven的pom.xml中添加deeplearning4j相关的依赖:
<!-- Deeplearning4j core artifact (deeplearning4j-core). -->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>1.0.0-M1.1</version>
</dependency>
<!-- ND4J native backend, platform bundle; declared with the same version as deeplearning4j-core. -->
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>1.0.0-M1.1</version>
</dependency>
2、构建LeNet-5神经网络训练模型,MNIST数据集在deeplearning4j包中能够直接加载,最终的训练结果存于model.zip中。
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
/**
 * Trains a LeNet-style convolutional network on the MNIST data set and writes the trained
 * network to {@code model.zip} in the working directory.
 */
public class MnistTrain {

    /**
     * Builds the network, fits it on the MNIST training split for five passes, prints the
     * evaluation statistics for the test split and serializes the model to {@code model.zip}.
     *
     * @throws Exception any exception raised by data loading, training or serialization is
     *     propagated unchanged
     */
    public void train() throws Exception {
        final int batchSize = 32;
        final int rngSeed = 12345;
        final int epochs = 5;

        // Iterators over the training split (train = true) and the test split (train = false).
        DataSetIterator trainingData = new MnistDataSetIterator(batchSize, true, rngSeed);
        DataSetIterator testData = new MnistDataSetIterator(batchSize, false, rngSeed);

        // Create and initialise the network from its configuration.
        MultiLayerNetwork network = new MultiLayerNetwork(buildConfiguration(rngSeed));
        network.init();
        network.setListeners(new ScoreIterationListener(10));

        // One fit(...) call per pass over the training iterator.
        int epoch = 0;
        while (epoch < epochs) {
            network.fit(trainingData);
            epoch++;
        }

        // Evaluate on the test split and print the statistics.
        Evaluation evaluation = network.evaluate(testData);
        System.out.println(evaluation.stats());

        // Persist the network; the final argument (true) asks the serializer to store the
        // updater as well.
        File target = new File("model.zip");
        ModelSerializer.writeModel(network, target, true);
    }

    /**
     * Assembles the six-layer configuration: two convolution/max-pooling pairs, a dense layer
     * with 500 outputs and a 10-way softmax output layer, fed with flattened 28x28x1 input.
     *
     * @param seed seed passed to the configuration builder
     * @return the finished multi-layer configuration
     */
    private static MultiLayerConfiguration buildConfiguration(long seed) {
        Layer firstConvolution = new ConvolutionLayer.Builder(5, 5)
                .nIn(1)
                .stride(1, 1)
                .nOut(20)
                .activation(Activation.IDENTITY)
                .build();
        Layer firstPooling = new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build();
        Layer secondConvolution = new ConvolutionLayer.Builder(5, 5)
                .stride(1, 1)
                .nOut(50)
                .activation(Activation.IDENTITY)
                .build();
        Layer secondPooling = new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build();
        Layer dense = new DenseLayer.Builder().nOut(500).build();
        Layer classifier = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(10)
                .activation(Activation.SOFTMAX)
                .build();

        return new NeuralNetConfiguration.Builder()
                .seed(seed)
                .l2(0.0005)
                .updater(new Adam(0.001))
                .list()
                .layer(0, firstConvolution)
                .layer(1, firstPooling)
                .layer(2, secondConvolution)
                .layer(3, secondPooling)
                .layer(4, dense)
                .layer(5, classifier)
                .setInputType(InputType.convolutionalFlat(28, 28, 1))
                .build();
    }
}
3、使用训练好的模型进行预测
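先构建预测器,用于加载模型(注意:下面的代码从 ./lib/model.zip 加载模型,而训练代码将模型保存为当前目录下的 model.zip,因此需要先把训练得到的 model.zip 复制到 lib 目录下,或相应修改路径)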
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.File;
import java.io.IOException;
/**
 * Loads a serialized network from {@code ./lib/model.zip} and uses it to classify
 * 28x28 grayscale images.
 */
public class MnistPredictor {
    /** Side length, in pixels, of the square image the network input is reshaped to. */
    private static final int IMAGE_SIZE = 28;
    /** Divisor used to scale pixel values into the range 0..1. */
    private static final float MAX_PIXEL_VALUE = 255.0f;

    private final MultiLayerNetwork model;

    /**
     * Restores the network from {@code ./lib/model.zip}.
     *
     * @throws IllegalStateException if the model file cannot be read; the original
     *     {@link IOException} is attached as the cause. Failing here avoids a later
     *     {@code NullPointerException} in {@link #predict(int[][])}.
     */
    public MnistPredictor() {
        File modelFile = new File("./lib/model.zip");
        try {
            model = ModelSerializer.restoreMultiLayerNetwork(modelFile);
        } catch (IOException e) {
            throw new IllegalStateException(
                    "Failed to load model from " + modelFile.getAbsolutePath(), e);
        }
    }

    /**
     * Classifies one image.
     *
     * @param inputImage 28 rows of 28 pixel values; values are divided by 255 before inference
     * @return index of the largest entry along dimension 1 of the network output
     * @throws IllegalArgumentException if {@code inputImage} is null or is not 28x28
     */
    public int predict(int[][] inputImage) {
        if (inputImage == null || inputImage.length != IMAGE_SIZE) {
            throw new IllegalArgumentException(
                    "inputImage must have exactly " + IMAGE_SIZE + " rows");
        }
        // Normalise in floating point before handing the data to ND4J, so the scaling does not
        // depend on how an in-place division behaves on an array created from int values.
        float[][] normalized = new float[IMAGE_SIZE][IMAGE_SIZE];
        for (int row = 0; row < IMAGE_SIZE; row++) {
            int[] pixels = inputImage[row];
            if (pixels == null || pixels.length != IMAGE_SIZE) {
                throw new IllegalArgumentException(
                        "inputImage row " + row + " must have exactly " + IMAGE_SIZE + " columns");
            }
            for (int col = 0; col < IMAGE_SIZE; col++) {
                normalized[row][col] = pixels[col] / MAX_PIXEL_VALUE;
            }
        }
        // Reshape to a 4-d array of shape (1, 1, 28, 28) as the original code did.
        INDArray input = Nd4j.create(normalized).reshape(1, 1, IMAGE_SIZE, IMAGE_SIZE);
        INDArray output = model.output(input);
        return Nd4j.argMax(output, 1).getInt(0);
    }
}
然后写主函数,实现预测
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
/**
 * Command-line entry point: reads {@code img.jpg}, binarizes it, rescales it to 28x28,
 * writes the processed image to {@code ./img2.png} for inspection and prints the digit
 * returned by {@link MnistPredictor}.
 */
public class MnistPredict {
    /** Width and height of the image passed to the predictor. */
    private static final int MODEL_INPUT_SIZE = 28;
    /** Pixels whose gray level is at or below this value are treated as ink. */
    private static final int INK_THRESHOLD = 120;
    /** Interpolated values above this level become foreground after resizing. */
    private static final int RESIZE_THRESHOLD = 100;
    /** Foreground value written into the binarized arrays. */
    private static final int FOREGROUND = 255;

    /**
     * Runs the load / preprocess / predict pipeline.
     *
     * @param args unused
     * @throws Exception if the image cannot be read or decoded, or if the predictor fails
     */
    public static void main(String[] args) throws Exception {
        // Load the trained model first, as before.
        MnistPredictor predictor = new MnistPredictor();

        File imageFile = new File("img.jpg");
        BufferedImage img = ImageIO.read(imageFile);
        if (img == null) {
            // ImageIO.read returns null when no registered reader can decode the file.
            throw new java.io.IOException("Unsupported or unreadable image file: " + imageFile);
        }

        int[][] binary = binarize(img);
        // The predictor reshapes its input to 28x28, so rescale to that size.
        int[][] resized = resizeImage(binary, MODEL_INPUT_SIZE, MODEL_INPUT_SIZE);
        // Save the preprocessed image so the preprocessing can be checked by eye.
        // save() writes PNG data, so the file name uses the matching extension.
        save(resized, "./img2.png");

        int prediction = predictor.predict(resized);
        System.out.println("The input image is predicted to be digit " + prediction);
    }

    /**
     * Converts an image to a binary array indexed as {@code [row][column]}: pixels whose
     * weighted gray level is at most {@link #INK_THRESHOLD} become 255, all others stay 0
     * (dark pixels in the source become bright pixels in the result).
     */
    private static int[][] binarize(BufferedImage img) {
        int width = img.getWidth();
        int height = img.getHeight();
        int[][] result = new int[height][width];
        for (int row = 0; row < height; row++) {
            for (int col = 0; col < width; col++) {
                int rgb = img.getRGB(col, row);
                int red = (rgb >> 16) & 0xff;
                int green = (rgb >> 8) & 0xff;
                int blue = rgb & 0xff;
                int gray = (int) (0.2989 * red + 0.5870 * green + 0.1140 * blue);
                if (gray <= INK_THRESHOLD) {
                    result[row][col] = FOREGROUND;
                }
            }
        }
        return result;
    }

    /**
     * Writes a {@code [row][column]} array of gray levels to {@code path} as a PNG image.
     * Failures are reported on standard error and otherwise ignored, because the saved file
     * is only a debugging aid.
     *
     * @param data non-empty rectangular array of gray levels (0..255)
     * @param path destination file
     */
    private static void save(int[][] data, String path) {
        int height = data.length;
        int width = data[0].length;
        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int gray = data[y][x];
                // Replicate the gray level into the red, green and blue channels.
                int rgb = (gray << 16) | (gray << 8) | gray;
                image.setRGB(x, y, rgb);
            }
        }
        try {
            if (!ImageIO.write(image, "png", new File(path))) {
                System.err.println("No PNG writer available; image not saved to " + path);
            }
        } catch (java.io.IOException e) {
            System.err.println("Could not save image to " + path + ": " + e);
        }
    }

    /**
     * Rescales a {@code [row][column]} image with bilinear interpolation and re-binarizes the
     * result: interpolated values above {@link #RESIZE_THRESHOLD} become 255, others 0.
     *
     * @param input non-empty rectangular source image
     * @param newWidth number of columns in the result, must be positive
     * @param newHeight number of rows in the result, must be positive
     * @return array with {@code newHeight} rows and {@code newWidth} columns
     * @throws IllegalArgumentException if the input is empty or a target dimension is not positive
     */
    private static int[][] resizeImage(int[][] input, int newWidth, int newHeight) {
        if (input == null || input.length == 0 || input[0].length == 0) {
            throw new IllegalArgumentException("input image must not be empty");
        }
        if (newWidth <= 0 || newHeight <= 0) {
            throw new IllegalArgumentException(
                    "target size must be positive: " + newWidth + "x" + newHeight);
        }
        int height = input.length;
        int width = input[0].length;
        // Rows first, columns second, matching the output[y][x] indexing below.
        int[][] output = new int[newHeight][newWidth];
        float widthRatio = (float) width / newWidth;
        float heightRatio = (float) height / newHeight;
        for (int y = 0; y < newHeight; y++) {
            float srcY = y * heightRatio;
            int py = Math.min((int) srcY, height - 1);
            // Clamp the neighbour row so the last row interpolates with itself.
            int pyNext = Math.min(py + 1, height - 1);
            float yDiff = srcY - py;
            for (int x = 0; x < newWidth; x++) {
                float srcX = x * widthRatio;
                int px = Math.min((int) srcX, width - 1);
                // Clamp the neighbour column so the last column interpolates with itself.
                int pxNext = Math.min(px + 1, width - 1);
                float xDiff = srcX - px;

                int topLeft = input[py][px];
                int topRight = input[py][pxNext];
                int bottomLeft = input[pyNext][px];
                int bottomRight = input[pyNext][pxNext];

                float topAvg = topLeft + xDiff * (topRight - topLeft);
                float bottomAvg = bottomLeft + xDiff * (bottomRight - bottomLeft);
                float avg = topAvg + yDiff * (bottomAvg - topAvg);
                // Keep the result binary instead of storing the interpolated gray level.
                output[y][x] = (avg <= RESIZE_THRESHOLD) ? 0 : FOREGROUND;
            }
        }
        return output;
    }
}