Spring Boot集成DeepLearning4j实现图片数字识别

1.什么是DeepLearning4j?

DeepLearning4J(DL4J)是一套基于Java语言的神经网络工具包,可以构建、定型和部署神经网络。DL4J与Hadoop和Spark集成,支持分布式CPU和GPU,为商业环境(而非研究工具目的)所设计。Skymind是DL4J的商业支持机构。 Deeplearning4j拥有先进的技术,以即插即用为目标,通过更多预设的使用,避免多余的配置,让非企业也能够进行快速的原型制作。DL4J同时可以规模化定制。DL4J遵循Apache 2.0许可协议,一切以其为基础的衍生作品均属于衍生作品的作

Deeplearning4j的功能

Deeplearning4j包括了分布式、多线程的深度学习框架,以及普通的单线程深度学习框架。定型过程以集群进行,也就是说,Deeplearning4j可以快速处理大量数据。神经网络可通过[迭代化简]平行定型,与 Java、 Scala 和 Clojure 均兼容。Deeplearning4j在开放堆栈中作为模块组件的功能,使之成为首个为微服务架构打造的深度学习框架。

Deeplearning4j的组件

深度神经网络能够实现前所未有的准确度。对神经网络的简介请参见概览页。简而言之,Deeplearning4j能够让你从各类浅层网络(其中每一层在英文中被称为layer)出发,设计深层神经网络。这一灵活性使用户可以根据所需,在分布式、生产级、能够在分布式CPU或GPU的基础上与Spark和Hadoop协同工作的框架内,整合受限玻尔兹曼机、其他自动编码器、卷积网络或递归网络。 此处为我们已经建立的各个库及其在系统整体中的所处位置:  

dl4j-ecosystem-cn-small

DeepLearning4J用于设计神经网络:

  • Deeplearning4j(简称DL4J)是为Java和Scala编写的首个商业级开源分布式深度学习
  • DL4J与Hadoop和Spark集成,为商业环境(而非研究工具目的)所设计。
  • 支持GPU和CPU
  • 受到 Cloudera, Hortonwork, NVIDIA, Intel, IBM 等认证,可以在Spark, Flink, Hadoop 上运行
  • 支持并行迭代算法架构
  • DeepLearning4J的JavaDoc可在此处获取
  • DeepLearning4J示例的Github代码库请见此处。相关示例的简介汇总请见此处
  • 开源工具 ASF 2.0许可证:github.com/deeplearning4j/deeplearning4j

2.训练模型

训练和测试数据集下载

https://raw.githubusercontent.com/zq2599/blog_download_files/master/files/mnist_png.tar.gz

MNIST简介
  • MNIST是经典的计算机视觉数据集,来源是National Institute of Standards and Technology (NIST,美国国家标准与技术研究所),包含各种手写数字图片,其中训练集60,000张,测试集 10,000张,
  • MNIST来源于250 个不同人的手写,其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员.,测试集(test set) 也是同样比例的手写数字数据
  • MNIST官网:http://yann.lecun.com/exdb/mnist/
数据集简介

从MNIST官网下载的原始数据并非图片文件,需要按官方给出的格式说明做解析处理才能转为一张张图片,这些事情显然不是本篇的主题,因此咱们可以直接使用DL4J为我们准备好的数据集(下载地址稍后给出),该数据集中是一张张独立的图片,这些图片所在目录的名字就是该图片具体的数字

模型训练

LeNet-5简介

1351564-20180827204056354-1429986291

LeNet-5 结构:
  • 输入层

图片大小为 32×32×1,其中 1 表示为黑白图像,只有一个 channel。

  • 卷积层

filter 大小 5×5,filter 深度(个数)为 6,padding 为 0, 卷积步长 s=1=1,输出矩阵大小为 28×28×6,其中 6 表示 filter 的个数。

  • 池化层

average pooling,filter 大小 2×2(即 f=2=2),步长 s=2=2,no padding,输出矩阵大小为 14×14×6。

  • 卷积层

filter 大小 5×5,filter 个数为 16,padding 为 0, 卷积步长 s=1=1,输出矩阵大小为 10×10×16,其中 16 表示 filter 的个数。

  • 池化层

average pooling,filter 大小 2×2(即 f=2=2),步长 s=2=2,no padding,输出矩阵大小为 5×5×16。注意,在该层结束,需要将 5×5×16 的矩阵flatten 成一个 400 维的向量。

  • 全连接层(Fully Connected layer,FC)

neuron 数量为 120。

  • 全连接层(Fully Connected layer,FC)

neuron 数量为 84。

  • 全连接层,输出层

现在版本的 LeNet-5 输出层一般会采用 softmax 激活函数,在 LeNet-5 提出的论文中使用的激活函数不是 softmax,但其现在不常用。该层神经元数量为 10,代表 0~9 十个数字类别。(图 1 其实少画了一个表示全连接层的方框,而直接用 ^y^ 表示输出层。)  

 
  1. /*******************************************************************************

  2. * Copyright (c) 2020 Konduit K.K.

  3. * Copyright (c) 2015-2019 Skymind, Inc.

  4. *

  5. * This program and the accompanying materials are made available under the

  6. * terms of the Apache License, Version 2.0 which is available at

  7. * https://www.apache.org/licenses/LICENSE-2.0.

  8. *

  9. * Unless required by applicable law or agreed to in writing, software

  10. * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT

  11. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the

  12. * License for the specific language governing permissions and limitations

  13. * under the License.

  14. *

  15. * SPDX-License-Identifier: Apache-2.0

  16. ******************************************************************************/

  17. package com.et.dl4j.model;

  18. import lombok.extern.slf4j.Slf4j;

  19. import org.datavec.api.io.labels.ParentPathLabelGenerator;

  20. import org.datavec.api.split.FileSplit;

  21. import org.datavec.image.loader.NativeImageLoader;

  22. import org.datavec.image.recordreader.ImageRecordReader;

  23. import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;

  24. import org.deeplearning4j.nn.conf.MultiLayerConfiguration;

  25. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;

  26. import org.deeplearning4j.nn.conf.inputs.InputType;

  27. import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;

  28. import org.deeplearning4j.nn.conf.layers.DenseLayer;

  29. import org.deeplearning4j.nn.conf.layers.OutputLayer;

  30. import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;

  31. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

  32. import org.deeplearning4j.nn.weights.WeightInit;

  33. import org.deeplearning4j.optimize.listeners.ScoreIterationListener;

  34. import org.deeplearning4j.util.ModelSerializer;

  35. import org.nd4j.evaluation.classification.Evaluation;

  36. import org.nd4j.linalg.activations.Activation;

  37. import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

  38. import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;

  39. import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;

  40. import org.nd4j.linalg.learning.config.Nesterovs;

  41. import org.nd4j.linalg.lossfunctions.LossFunctions;

  42. import org.nd4j.linalg.schedule.MapSchedule;

  43. import org.nd4j.linalg.schedule.ScheduleType;

  44. import java.io.File;

  45. import java.util.HashMap;

  46. import java.util.Map;

  47. import java.util.Random;

  48. /**

  49. * Implementation of LeNet-5 for handwritten digits image classification on MNIST dataset (99% accuracy)

  50. * <a href="http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf">[LeCun et al., 1998. Gradient based learning applied to document recognition]</a>

  51. * Some minor changes are made to the architecture like using ReLU and identity activation instead of

  52. * sigmoid/tanh, max pooling instead of avg pooling and softmax output layer.

  53. * <p>

  54. * This example will download 15 Mb of data on the first run.

  55. *

  56. * @author hanlon

  57. * @author agibsonccc

  58. * @author fvaleri

  59. * @author dariuszzbyrad

  60. */

  61. @Slf4j

  62. public class LeNetMNISTReLu {

  63. //dataset github:https://raw.githubusercontent.com/zq2599/blog_download_files/master/files/mnist_png.tar.gz

  64. // 存放文件的地址,请酌情修改

  65. // private static final String BASE_PATH = System.getProperty("java.io.tmpdir") + "/mnist";

  66. private static final String BASE_PATH = "/Users/liuhaihua/Downloads";

  67. public static void main(String[] args) throws Exception {

  68. // 图片像素高

  69. int height = 28;

  70. // 图片像素宽

  71. int width = 28;

  72. // 因为是黑白图像,所以颜色通道只有一个

  73. int channels = 1;

  74. // 分类结果,0-9,共十种数字

  75. int outputNum = 10;

  76. // 批大小

  77. int batchSize = 54;

  78. // 循环次数

  79. int nEpochs = 1;

  80. // 初始化伪随机数的种子

  81. int seed = 1234;

  82. // 随机数工具

  83. Random randNumGen = new Random(seed);

  84. log.info("检查数据集文件夹是否存在:{}", BASE_PATH + "/mnist_png");

  85. if (!new File(BASE_PATH + "/mnist_png").exists()) {

  86. log.info("数据集文件不存在,请下载压缩包并解压到:{}", BASE_PATH);

  87. return;

  88. }

  89. // 标签生成器,将指定文件的父目录作为标签

  90. ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();

  91. // 归一化配置(像素值从0-255变为0-1)

  92. DataNormalization imageScaler = new ImagePreProcessingScaler();

  93. // 不论训练集还是测试集,初始化操作都是相同套路:

  94. // 1. 读取图片,数据格式为NCHW

  95. // 2. 根据批大小创建的迭代器

  96. // 3. 将归一化器作为预处理器

  97. log.info("训练集的矢量化操作...");

  98. // 初始化训练集

  99. File trainData = new File(BASE_PATH + "/mnist_png/training");

  100. FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);

  101. ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);

  102. trainRR.initialize(trainSplit);

  103. DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);

  104. // 拟合数据(实现类中实际上什么也没做)

  105. imageScaler.fit(trainIter);

  106. trainIter.setPreProcessor(imageScaler);

  107. log.info("测试集的矢量化操作...");

  108. // 初始化测试集,与前面的训练集操作类似

  109. File testData = new File(BASE_PATH + "/mnist_png/testing");

  110. FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);

  111. ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);

  112. testRR.initialize(testSplit);

  113. DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);

  114. testIter.setPreProcessor(imageScaler); // same normalization for better results

  115. log.info("配置神经网络");

  116. // 在训练中,将学习率配置为随着迭代阶梯性下降

  117. Map<Integer, Double> learningRateSchedule = new HashMap<>();

  118. learningRateSchedule.put(0, 0.06);

  119. learningRateSchedule.put(200, 0.05);

  120. learningRateSchedule.put(600, 0.028);

  121. learningRateSchedule.put(800, 0.0060);

  122. learningRateSchedule.put(1000, 0.001);

  123. // 超参数

  124. MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()

  125. .seed(seed)

  126. // L2正则化系数

  127. .l2(0.0005)

  128. // 梯度下降的学习率设置

  129. .updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)))

  130. // 权重初始化

  131. .weightInit(WeightInit.XAVIER)

  132. // 准备分层

  133. .list()

  134. // 卷积层

  135. .layer(new ConvolutionLayer.Builder(5, 5)

  136. .nIn(channels)

  137. .stride(1, 1)

  138. .nOut(20)

  139. .activation(Activation.IDENTITY)

  140. .build())

  141. // 下采样,即池化

  142. .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)

  143. .kernelSize(2, 2)

  144. .stride(2, 2)

  145. .build())

  146. // 卷积层

  147. .layer(new ConvolutionLayer.Builder(5, 5)

  148. .stride(1, 1) // nIn need not specified in later layers

  149. .nOut(50)

  150. .activation(Activation.IDENTITY)

  151. .build())

  152. // 下采样,即池化

  153. .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)

  154. .kernelSize(2, 2)

  155. .stride(2, 2)

  156. .build())

  157. // 稠密层,即全连接

  158. .layer(new DenseLayer.Builder().activation(Activation.RELU)

  159. .nOut(500)

  160. .build())

  161. // 输出

  162. .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)

  163. .nOut(outputNum)

  164. .activation(Activation.SOFTMAX)

  165. .build())

  166. .setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image

  167. .build();

  168. MultiLayerNetwork net = new MultiLayerNetwork(conf);

  169. net.init();

  170. // 每十个迭代打印一次损失函数值

  171. net.setListeners(new ScoreIterationListener(10));

  172. log.info("神经网络共[{}]个参数", net.numParams());

  173. long startTime = System.currentTimeMillis();

  174. // 循环操作

  175. for (int i = 0; i < nEpochs; i++) {

  176. log.info("第[{}]个循环", i);

  177. net.fit(trainIter);

  178. Evaluation eval = net.evaluate(testIter);

  179. log.info(eval.stats());

  180. trainIter.reset();

  181. testIter.reset();

  182. }

  183. log.info("完成训练和测试,耗时[{}]毫秒", System.currentTimeMillis()-startTime);

  184. // 保存模型

  185. File ministModelPath = new File(BASE_PATH + "/minist-model.zip");

  186. ModelSerializer.writeModel(net, ministModelPath, true);

  187. log.info("最新的MINIST模型保存在[{}]", ministModelPath.getPath());

  188. }

  189. }

输出模型文件和得分结果

dl2

3.编写模型预测接口

pom.xml
 
  1. <?xml version="1.0" encoding="UTF-8"?>

  2. <project xmlns="http://maven.apache.org/POM/4.0.0"

  3. xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"

  4. xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

  5. <parent>

  6. <artifactId>springboot-demo</artifactId>

  7. <groupId>com.et</groupId>

  8. <version>1.0-SNAPSHOT</version>

  9. </parent>

  10. <modelVersion>4.0.0</modelVersion>

  11. <artifactId>Deeplearning4j</artifactId>

  12. <properties>

  13. <maven.compiler.source>8</maven.compiler.source>

  14. <maven.compiler.target>8</maven.compiler.target>

  15. <dl4j-master.version>1.0.0-beta7</dl4j-master.version>

  16. <nd4j.backend>nd4j-native</nd4j.backend>

  17. </properties>

  18. <dependencies>

  19. <dependency>

  20. <groupId>org.springframework.boot</groupId>

  21. <artifactId>spring-boot-starter-web</artifactId>

  22. </dependency>

  23. <dependency>

  24. <groupId>org.springframework.boot</groupId>

  25. <artifactId>spring-boot-autoconfigure</artifactId>

  26. </dependency>

  27. <dependency>

  28. <groupId>org.springframework.boot</groupId>

  29. <artifactId>spring-boot-starter-test</artifactId>

  30. <scope>test</scope>

  31. </dependency>

  32. <dependency>

  33. <groupId>org.projectlombok</groupId>

  34. <artifactId>lombok</artifactId>

  35. <version>1.18.20</version>

  36. </dependency>

  37. <dependency>

  38. <groupId>ch.qos.logback</groupId>

  39. <artifactId>logback-classic</artifactId>

  40. </dependency>

  41. <dependency>

  42. <groupId>org.deeplearning4j</groupId>

  43. <artifactId>deeplearning4j-core</artifactId>

  44. <version>${dl4j-master.version}</version>

  45. </dependency>

  46. <dependency>

  47. <groupId>org.nd4j</groupId>

  48. <artifactId>${nd4j.backend}</artifactId>

  49. <version>${dl4j-master.version}</version>

  50. </dependency>

  51. <!--用于本地GPU-->

  52. <!-- <dependency>-->

  53. <!-- <groupId>org.deeplearning4j</groupId>-->

  54. <!-- <artifactId>deeplearning4j-cuda-9.2</artifactId>-->

  55. <!-- <version>${dl4j-master.version}</version>-->

  56. <!-- </dependency>-->

  57. <!-- <dependency>-->

  58. <!-- <groupId>org.nd4j</groupId>-->

  59. <!-- <artifactId>nd4j-cuda-9.2-platform</artifactId>-->

  60. <!-- <version>${dl4j-master.version}</version>-->

  61. <!-- </dependency>-->

  62. </dependencies>

  63. </project>

cotroller
 
  1. package com.et.dl4j.controller;

  2. import com.et.dl4j.service.PredictService;

  3. import org.springframework.beans.factory.annotation.Autowired;

  4. import org.springframework.web.bind.annotation.*;

  5. import org.springframework.web.multipart.MultipartFile;

  6. import java.util.HashMap;

  7. import java.util.Map;

  8. @RestController

  9. public class HelloWorldController {

  10. @RequestMapping("/hello")

  11. public Map<String, Object> showHelloWorld(){

  12. Map<String, Object> map = new HashMap<>();

  13. map.put("msg", "HelloWorld");

  14. return map;

  15. }

  16. @Autowired

  17. PredictService predictService;

  18. @PostMapping("/predict-with-black-background")

  19. public int predictWithBlackBackground(@RequestParam("file") MultipartFile file) throws Exception {

  20. // 训练模型的时候,用的数字是白字黑底,

  21. // 因此如果上传白字黑底的图片,可以直接拿去识别,而无需反色处理

  22. return predictService.predict(file, false);

  23. }

  24. @PostMapping("/predict-with-white-background")

  25. public int predictWithWhiteBackground(@RequestParam("file") MultipartFile file) throws Exception {

  26. // 训练模型的时候,用的数字是白字黑底,

  27. // 因此如果上传黑字白底的图片,就需要做反色处理,

  28. // 反色之后就是白字黑底了,可以拿去识别

  29. return predictService.predict(file, true);

  30. }

  31. }

service
 
  1. package com.et.dl4j.service;

  2. import org.springframework.web.multipart.MultipartFile;

  3. public interface PredictService {

  4. /**

  5. * 取得上传的图片,做转换后识别成数字

  6. * @param file 上传的文件

  7. * @param isNeedRevert 是否要做反色处理

  8. * @return

  9. */

  10. int predict(MultipartFile file, boolean isNeedRevert) throws Exception ;

  11. }

 
  1. package com.et.dl4j.service.impl;

  2. import com.et.dl4j.service.PredictService;

  3. import com.et.dl4j.util.ImageFileUtil;

  4. import lombok.extern.slf4j.Slf4j;

  5. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

  6. import org.deeplearning4j.util.ModelSerializer;

  7. import org.nd4j.linalg.api.ndarray.INDArray;

  8. import org.springframework.beans.factory.annotation.Value;

  9. import org.springframework.stereotype.Service;

  10. import org.springframework.web.multipart.MultipartFile;

  11. import javax.annotation.PostConstruct;

  12. import java.io.File;

  13. @Service

  14. @Slf4j

  15. public class PredictServiceImpl implements PredictService {

  16. /**

  17. * -1表示识别失败

  18. */

  19. private static final int RLT_INVALID = -1;

  20. /**

  21. * 模型文件的位置

  22. */

  23. @Value("${predict.modelpath}")

  24. private String modelPath;

  25. /**

  26. * 处理图片文件的目录

  27. */

  28. @Value("${predict.imagefilepath}")

  29. private String imageFilePath;

  30. /**

  31. * 神经网络

  32. */

  33. private MultiLayerNetwork net;

  34. /**

  35. * bean实例化成功就加载模型

  36. */

  37. @PostConstruct

  38. private void loadModel() {

  39. log.info("load model from [{}]", modelPath);

  40. // 加载模型

  41. try {

  42. net = ModelSerializer.restoreMultiLayerNetwork(new File(modelPath));

  43. log.info("module summary\n{}", net.summary());

  44. } catch (Exception exception) {

  45. log.error("loadModel error", exception);

  46. }

  47. }

  48. @Override

  49. public int predict(MultipartFile file, boolean isNeedRevert) throws Exception {

  50. log.info("start predict, file [{}], isNeedRevert [{}]", file.getOriginalFilename(), isNeedRevert);

  51. // 先存文件

  52. String rawFileName = ImageFileUtil.save(imageFilePath, file);

  53. if (null==rawFileName) {

  54. return RLT_INVALID;

  55. }

  56. // 反色处理后的文件名

  57. String revertFileName = null;

  58. // 调整大小后的文件名

  59. String resizeFileName;

  60. // 是否需要反色处理

  61. if (isNeedRevert) {

  62. // 把原始文件做反色处理,返回结果是反色处理后的新文件

  63. revertFileName = ImageFileUtil.colorRevert(imageFilePath, rawFileName);

  64. // 把反色处理后调整为28*28大小的文件

  65. resizeFileName = ImageFileUtil.resize(imageFilePath, revertFileName);

  66. } else {

  67. // 直接把原始文件调整为28*28大小的文件

  68. resizeFileName = ImageFileUtil.resize(imageFilePath, rawFileName);

  69. }

  70. // 现在已经得到了结果反色和调整大小处理过后的文件,

  71. // 那么原始文件和反色处理过的文件就可以删除了

  72. ImageFileUtil.clear(imageFilePath, rawFileName, revertFileName);

  73. // 取出该黑白图片的特征

  74. INDArray features = ImageFileUtil.getGrayImageFeatures(imageFilePath, resizeFileName);

  75. // 将特征传给模型去识别

  76. return net.predict(features)[0];

  77. }

  78. }

application.properties
 
  1. # 上传文件总的最大值

  2. spring.servlet.multipart.max-request-size=1024MB

  3. # 单个文件的最大值

  4. spring.servlet.multipart.max-file-size=10MB

  5. # 处理图片文件的目录

  6. predict.imagefilepath=/Users/liuhaihua/Downloads/images/

  7. # 模型所在位置

  8. predict.modelpath=/Users/liuhaihua/Downloads/minist-model.zip

工具类
 
  1. package com.et.dl4j.util;

  2. import lombok.extern.slf4j.Slf4j;

  3. import org.datavec.api.split.FileSplit;

  4. import org.datavec.image.loader.NativeImageLoader;

  5. import org.datavec.image.recordreader.ImageRecordReader;

  6. import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;

  7. import org.nd4j.linalg.api.ndarray.INDArray;

  8. import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

  9. import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;

  10. import org.springframework.web.multipart.MultipartFile;

  11. import javax.imageio.ImageIO;

  12. import java.awt.*;

  13. import java.awt.image.BufferedImage;

  14. import java.io.File;

  15. import java.io.IOException;

  16. import java.util.UUID;

  17. @Slf4j

  18. public class ImageFileUtil {

  19. /**

  20. * 调整后的文件宽度

  21. */

  22. public static final int RESIZE_WIDTH = 28;

  23. /**

  24. * 调整后的文件高度

  25. */

  26. public static final int RESIZE_HEIGHT = 28;

  27. /**

  28. * 将上传的文件存在服务器上

  29. * @param base 要处理的文件所在的目录

  30. * @param file 要处理的文件

  31. * @return

  32. */

  33. public static String save(String base, MultipartFile file) {

  34. // 检查是否为空

  35. if (file.isEmpty()) {

  36. log.error("invalid file");

  37. return null;

  38. }

  39. // 文件名来自原始文件

  40. String fileName = file.getOriginalFilename();

  41. // 要保存的位置

  42. File dest = new File(base + fileName);

  43. // 开始保存

  44. try {

  45. file.transferTo(dest);

  46. } catch (IOException e) {

  47. log.error("upload fail", e);

  48. return null;

  49. }

  50. return fileName;

  51. }

  52. /**

  53. * 将图片转为28*28像素

  54. * @param base 处理文件的目录

  55. * @param fileName 待调整的文件名

  56. * @return

  57. */

  58. public static String resize(String base, String fileName) {

  59. // 新文件名是原文件名在加个随机数后缀,而且扩展名固定为png

  60. String resizeFileName = fileName.substring(0, fileName.lastIndexOf(".")) + "-" + UUID.randomUUID() + ".png";

  61. log.info("start resize, from [{}] to [{}]", fileName, resizeFileName);

  62. try {

  63. // 读原始文件

  64. BufferedImage bufferedImage = ImageIO.read(new File(base + fileName));

  65. // 缩放后的实例

  66. Image image = bufferedImage.getScaledInstance(RESIZE_WIDTH, RESIZE_HEIGHT, Image.SCALE_SMOOTH);

  67. BufferedImage resizeBufferedImage = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);

  68. Graphics graphics = resizeBufferedImage.getGraphics();

  69. // 绘图

  70. graphics.drawImage(image, 0, 0, null);

  71. graphics.dispose();

  72. // 转换后的图片写文件

  73. ImageIO.write(resizeBufferedImage, "png", new File(base + resizeFileName));

  74. } catch (Exception exception) {

  75. log.info("resize error from [{}] to [{}], {}", fileName, resizeFileName, exception);

  76. resizeFileName = null;

  77. }

  78. log.info("finish resize, from [{}] to [{}]", fileName, resizeFileName);

  79. return resizeFileName;

  80. }

  81. /**

  82. * 将RGB转为int数字

  83. * @param alpha

  84. * @param red

  85. * @param green

  86. * @param blue

  87. * @return

  88. */

  89. private static int colorToRGB(int alpha, int red, int green, int blue) {

  90. int pixel = 0;

  91. pixel += alpha;

  92. pixel = pixel << 8;

  93. pixel += red;

  94. pixel = pixel << 8;

  95. pixel += green;

  96. pixel = pixel << 8;

  97. pixel += blue;

  98. return pixel;

  99. }

  100. /**

  101. * 反色处理

  102. * @param base 处理文件的目录

  103. * @param src 用于处理的源文件

  104. * @return 反色处理后的新文件

  105. * @throws IOException

  106. */

  107. public static String colorRevert(String base, String src) throws IOException {

  108. int color, r, g, b, pixel;

  109. // 读原始文件

  110. BufferedImage srcImage = ImageIO.read(new File(base + src));

  111. // 修改后的文件

  112. BufferedImage destImage = new BufferedImage(srcImage.getWidth(), srcImage.getHeight(), srcImage.getType());

  113. for (int i=0; i<srcImage.getWidth(); i++) {

  114. for (int j=0; j<srcImage.getHeight(); j++) {

  115. color = srcImage.getRGB(i, j);

  116. r = (color >> 16) & 0xff;

  117. g = (color >> 8) & 0xff;

  118. b = color & 0xff;

  119. pixel = colorToRGB(255, 0xff - r, 0xff - g, 0xff - b);

  120. destImage.setRGB(i, j, pixel);

  121. }

  122. }

  123. // 反射文件的名字

  124. String revertFileName = src.substring(0, src.lastIndexOf(".")) + "-revert.png";

  125. // 转换后的图片写文件

  126. ImageIO.write(destImage, "png", new File(base + revertFileName));

  127. return revertFileName;

  128. }

  129. /**

  130. * 取黑白图片的特征

  131. * @param base

  132. * @param fileName

  133. * @return

  134. * @throws Exception

  135. */

  136. public static INDArray getGrayImageFeatures(String base, String fileName) throws Exception {

  137. log.info("start getImageFeatures [{}]", base + fileName);

  138. // 和训练模型时一样的设置

  139. ImageRecordReader imageRecordReader = new ImageRecordReader(RESIZE_HEIGHT, RESIZE_WIDTH, 1);

  140. FileSplit fileSplit = new FileSplit(new File(base + fileName),

  141. NativeImageLoader.ALLOWED_FORMATS);

  142. imageRecordReader.initialize(fileSplit);

  143. DataSetIterator dataSetIterator = new RecordReaderDataSetIterator(imageRecordReader, 1);

  144. dataSetIterator.setPreProcessor(new ImagePreProcessingScaler(0, 1));

  145. // 取特征

  146. return dataSetIterator.next().getFeatures();

  147. }

  148. /**

  149. * 批量清理文件

  150. * @param base 处理文件的目录

  151. * @param fileNames 待清理文件集合

  152. */

  153. public static void clear(String base, String...fileNames) {

  154. for (String fileName : fileNames) {

  155. if (null==fileName) {

  156. continue;

  157. }

  158. File file = new File(base + fileName);

  159. if (file.exists()) {

  160. file.delete();

  161. }

  162. }

  163. }

  164. }

DemoApplication.java
 
  1. package com.et.dl4j;

  2. import org.springframework.boot.SpringApplication;

  3. import org.springframework.boot.autoconfigure.SpringBootApplication;

  4. @SpringBootApplication

  5. public class DemoApplication {

  6. public static void main(String[] args) {

  7. SpringApplication.run(DemoApplication.class, args);

  8. }

  9. }

以上只是一些关键代码,所有代码请参见下面代码仓库

代码仓库

4.测试

启动Spring Boot应用,上传图片测试

  • 如果用户输入的是黑底白字的图片,只需要将上述流程中的反色处理去掉即可
  • 为白底黑字图片提供专用接口predict-with-white-background
  • 为黑底白字图片提供专用接口predict-with-black-background

dl3

5.引用

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值