Java使用deeplearning4j实现MNIST手写数字识别

前提:需要64位的JDK,32位的运行不了deeplearning4j

1、Maven添加deeplearning4j相关的jar包:

<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-M1.1</version>
</dependency>
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>1.0.0-M1.1</version>
</dependency>

2、构建LeNet-5神经网络训练模型,MNIST数据集在deeplearning4j包中能够直接加载,最终的训练结果存于model.zip中。

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;

public class MnistTrain {

    public void train() throws Exception {
        // 设置训练数据和测试数据迭代器
        DataSetIterator trainIter = new MnistDataSetIterator(32, true, 12345);
        DataSetIterator testIter = new MnistDataSetIterator(32, false, 12345);
        // 构建神经网络的配置
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .l2(0.0005)
                .updater(new Adam(0.001))
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        .nIn(1)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(4, new DenseLayer.Builder().nOut(500).build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(10)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(28, 28, 1))
                .build();

        // 初始化模型并设置参数
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        model.setListeners(new ScoreIterationListener(10));

        // 训练模型
        for (int i = 0; i < 5; i++) {
            model.fit(trainIter);
        }

        // 在测试数据上评估模型
        Evaluation eval = model.evaluate(testIter);
        System.out.println(eval.stats());

        // 将模型保存到本地磁盘
        File locationToSave = new File("model.zip"); // 将要保存的文件位置
        boolean saveUpdater = true; // 是否保存updater(用于进行模型参数更新)
        ModelSerializer.writeModel(model, locationToSave, saveUpdater);
    }
}

3、使用训练好的模型进行预测

先构建预测器,用于加载模型

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.File;
import java.io.IOException;

public class MnistPredictor {
    private MultiLayerNetwork model;

    public MnistPredictor() {
        // 加载预训练好的LeNet模型
        File locationToSave = new File("./lib/model.zip");
        try {
            model = ModelSerializer.restoreMultiLayerNetwork(locationToSave);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public int predict(int[][] inputImage) {
        // 将输入图像转换为INDArray格式,并进行归一化处理
        INDArray input = Nd4j.create(inputImage).reshape(1, 1, 28, 28).divi(255.0);
        // 对图像进行推理,并返回预测结果的索引
        INDArray output = model.output(input);
        return Nd4j.argMax(output, 1).getInt(0);
    }
}

然后写主函数,实现预测

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;

public class MnistPredict {
    public static void main(String[] args) throws Exception {
        // 加载预训练好的模型
        MnistPredictor predictor = new MnistPredictor();

        // 加载图片
        BufferedImage img = ImageIO.read(new File("img.jpg"));
        int width = img.getWidth();
        int height = img.getHeight();

        // 将RGB图片变为灰度图,并将每个像素存于二维数组
        int[][] grayArray = new int[height][width];
        for (int i = 0; i < height; i++) {
            for (int j = 0; j < width; j++) {
                int rgb = img.getRGB(j, i);
                int gray = (int) (0.2989 * ((rgb >> 16) & 0xff) + 0.5870 * ((rgb >> 8) & 0xff) + 0.1140 * (rgb & 0xff));
                // 二值化
                if (gray <= 120)
                    grayArray[i][j] = 255;
            }
        }

        // 对当前的二维数组resize,因为训练好的模型要求输入图像大小为28×28
        int[][] expandedSubArray = resizeImage(grayArray, 28, 28);

        // 保存,看看处理后的图像是什么样的,检查有没有处理错误
        save(expandedSubArray, "./img2.jpg");
        
        // 预测
        int prediction = predictor.predict(expandedSubArray);

        System.out.println("The input image is predicted to be digit " + prediction);
    }

    private static void save(int[][] data, String path) {
        int width = data[0].length;
        int height = data.length;
        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);

        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int gray = data[y][x];
                int rgb = (gray << 16) | (gray << 8) | gray;
                image.setRGB(x, y, rgb);
            }
        }

        try {
            ImageIO.write(image, "png", new File(path));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private static int[][] resizeImage(int[][] input, int newWidth, int newHeight) {
        int[][] output = new int[newWidth][newHeight];
        int height = input.length;
        int width = input[0].length;
        float widthRatio = (float) width / newWidth;
        float heightRatio = (float) height / newHeight;
        for (int y = 0; y < newHeight; y++) {
            for (int x = 0; x < newWidth; x++) {
                int px = (int) (x * widthRatio);
                int py = (int) (y * heightRatio);
                float xDiff = (x * widthRatio) - px;
                float yDiff = (y * heightRatio) - py;
                int pixelTopLeft = input[py][px];
                int pixelTopRight = (px == width - 1) ? pixelTopLeft : input[py][px + 1];
                int pixelBottomLeft = (py == height - 1) ? pixelTopLeft : input[py + 1][px];
                int pixelBottomRight = (px == width - 1 || py == height - 1) ? pixelBottomLeft : input[py + 1][px + 1];
                float topAvg = pixelTopLeft + xDiff * (pixelTopRight - pixelTopLeft);
                float bottomAvg = pixelBottomLeft + xDiff * (pixelBottomRight - pixelBottomLeft);
                float avg = topAvg + yDiff * (bottomAvg - topAvg);
                if (avg <= 100)
                    output[y][x] = 0;
                else
                    output[y][x] = 255;
//                output[y][x] = (int) avg;
            }
        }
        return output;
    }
}

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

cyc头发还挺多的

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值