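Introduction
I've recently been working on my graduation project and doing some machine-learning research.
One thought keeps coming back to me: there is still a gap between our research and real applications. Suppose we have implemented a classification network. Exposing it as an application is often difficult, and usually only specialists can understand it and build on it. More precisely, it is not packaged well enough yet.
So lately I have been thinking about how to open these capabilities up to others, which led to this post.
This post uses the simplest possible example: run a Spring Boot project locally in IntelliJ IDEA, visit http://localhost:8080 in a browser to trigger network training, and return a completion message once training finishes. The network is MLPClassifierMoon, a classification example from the deeplearning4j project.
Project repository: https://github.com/xiaozhch5/spring-guides/tree/master/springboot-dl4j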
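To make the request flow concrete without Spring, here is a minimal sketch that uses the JDK's built-in `com.sun.net.httpserver.HttpServer` in place of Spring Boot. This is a swapped-in technique for illustration only, not what the project uses. A GET on `/` runs a "training" step on the request thread and replies with the completion message once it returns. The class `EndpointSketch` and the method `fakeTraining()` are hypothetical; `fakeTraining()` merely stands in for `MLPClassifierMoon.run()`. The sketch binds to 127.0.0.1 on a free port rather than 8080 so it cannot clash with the real application.

```java
import com.sun.net.httpserver.HttpServer;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.URI;
import java.nio.charset.StandardCharsets;

class EndpointSketch {

    // Hypothetical stand-in for MLPClassifierMoon.run(): "train", then report completion
    static String fakeTraining() throws InterruptedException {
        Thread.sleep(100); // placeholder for model.fit(...)
        return "train finished";
    }

    // Serve GET / on the loopback address; the handler blocks until training returns
    static HttpServer start(int port) throws IOException {
        HttpServer server = HttpServer.create(new InetSocketAddress("127.0.0.1", port), 0);
        server.createContext("/", exchange -> {
            String body;
            try {
                body = fakeTraining();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                body = "train interrupted";
            }
            byte[] bytes = body.getBytes(StandardCharsets.UTF_8);
            exchange.sendResponseHeaders(200, bytes.length);
            try (OutputStream os = exchange.getResponseBody()) {
                os.write(bytes);
            }
        });
        server.start();
        return server;
    }

    // Plain HTTP GET that returns the response body as UTF-8 text (Java 8 compatible)
    static String get(String url) throws IOException {
        HttpURLConnection conn = (HttpURLConnection) URI.create(url).toURL().openConnection(Proxy.NO_PROXY);
        try (InputStream in = conn.getInputStream()) {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] buf = new byte[4096];
            int n;
            while ((n = in.read(buf)) != -1) {
                out.write(buf, 0, n);
            }
            return new String(out.toByteArray(), StandardCharsets.UTF_8);
        } finally {
            conn.disconnect();
        }
    }

    // Start on a free port, request "/", and always stop the server afterwards
    static String roundTrip() throws IOException {
        HttpServer server = start(0);
        try {
            return get("http://127.0.0.1:" + server.getAddress().getPort() + "/");
        } finally {
            server.stop(0);
        }
    }

    public static void main(String[] args) throws IOException {
        System.out.println(roundTrip());
    }
}
```

The browser plays the role of `get(...)` here: it simply waits until the handler writes its response, which is exactly why the page only shows the message after training ends.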
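Creating the project
You can download the project directly from my GitHub repository, or follow the steps below.
First, create a Spring Boot project in IDEA.
POM file
The pom.xml file is as follows: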
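<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.2.5.RELEASE</version>
        <relativePath/>
    </parent>
    <groupId>com.weilian</groupId>
    <artifactId>springboot-dl4j</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>springboot-dl4j</name>
    <description>Demo project for Spring Boot</description>

    <properties>
        <java.version>1.8</java.version>
        <nd4j.backend>nd4j-native-platform</nd4j.backend>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <nd4j.version>1.0.0-beta6</nd4j.version>
        <dl4j.version>1.0.0-beta6</dl4j.version>
        <datavec.version>1.0.0-beta6</datavec.version>
        <arbiter.version>1.0.0-beta6</arbiter.version>
        <hadoop.version>2.2.0</hadoop.version>
        <logback.version>1.1.7</logback.version>
        <jfreechart.version>1.0.13</jfreechart.version>
        <jcommon.version>1.0.23</jcommon.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
            <exclusions>
                <exclusion>
                    <groupId>org.junit.vintage</groupId>
                    <artifactId>junit-vintage-engine</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${nd4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nlp</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-ui</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-parallel-wrapper</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-hadoop</artifactId>
            <version>${datavec.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-common</artifactId>
            <version>${hadoop.version}</version>
            <exclusions>
                <exclusion>
                    <groupId>jdk.tools</groupId>
                    <artifactId>jdk.tools</artifactId>
                </exclusion>
                <exclusion>
                    <groupId>log4j</groupId>
                    <artifactId>log4j</artifactId>
                </exclusion>
                <exclusion>
                    <groupId>org.slf4j</groupId>
                    <artifactId>slf4j-log4j12</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>arbiter-deeplearning4j</artifactId>
            <version>${arbiter.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>arbiter-ui</artifactId>
            <version>${arbiter.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-codec</artifactId>
            <version>${datavec.version}</version>
        </dependency>
        <dependency>
            <groupId>jfree</groupId>
            <artifactId>jfreechart</artifactId>
            <version>${jfreechart.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jfree</groupId>
            <artifactId>jcommon</artifactId>
            <version>${jcommon.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.httpcomponents</groupId>
            <artifactId>httpclient</artifactId>
            <version>4.3.6</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j.examples</groupId>
            <artifactId>shared-utilities</artifactId>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>
</project>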
Creating the controller
Create a controller class, com/weilian/springbootdl4j/controller/Dl4jController.java, with the following code:
package com.weilian.springbootdl4j.controller;

import com.weilian.springbootdl4j.service.MLPClassifierMoon;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
public class Dl4jController {

    // GET http://localhost:8080/ trains the network on the request thread
    // and returns the result of run() once training has finished
    @GetMapping("/")
    public String dl4jResult() throws Exception {
        return new MLPClassifierMoon().run();
    }
}
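Because `dl4jResult()` calls `run()` directly, the HTTP request stays open for all 100 epochs. For longer training runs, one common alternative is to start training in the background, hand back a job id at once, and let the client poll for the status. The sketch below shows that pattern with only `java.util.concurrent`; the class `TrainingJobs` and its methods are hypothetical and not part of this project. In a Spring controller, `submit` and `status` would each back their own endpoint.

```java
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

class TrainingJobs {

    // A single worker thread, so at most one training run executes at a time
    private final ExecutorService executor = Executors.newSingleThreadExecutor();
    private final Map<String, Future<String>> jobs = new ConcurrentHashMap<>();

    // Queue a training task, e.g. () -> new MLPClassifierMoon().run(), and return its id immediately
    public String submit(Callable<String> training) {
        String id = UUID.randomUUID().toString();
        jobs.put(id, executor.submit(training));
        return id;
    }

    // Never blocks: "unknown", "running" (queued or in progress), the task's result, or "failed: ..."
    public String status(String id) {
        Future<String> job = jobs.get(id);
        if (job == null) {
            return "unknown";
        }
        if (!job.isDone()) {
            return "running";
        }
        try {
            return job.get(); // the job is done, so this returns immediately
        } catch (ExecutionException e) {
            return "failed: " + e.getCause();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return "interrupted";
        }
    }

    public void shutdown() throws InterruptedException {
        executor.shutdown();
        executor.awaitTermination(1, TimeUnit.MINUTES);
    }

    // Demo: submit a fake job and poll until it is no longer running
    static String demo() throws InterruptedException {
        TrainingJobs trainingJobs = new TrainingJobs();
        String id = trainingJobs.submit(() -> {
            Thread.sleep(200); // stand-in for model.fit(trainIter, nEpochs)
            return "train finished";
        });
        while ("running".equals(trainingJobs.status(id))) {
            Thread.sleep(20);
        }
        String result = trainingJobs.status(id);
        trainingJobs.shutdown();
        return result;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(demo());
    }
}
```

A single-thread executor is a deliberate choice in this sketch: training is CPU-heavy, so running jobs one after another is usually safer than letting several requests train in parallel.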
Creating the service classes
Create com/weilian/springbootdl4j/service/DownloaderUtility.java with the following code:
package com.weilian.springbootdl4j.service;

import org.apache.commons.io.FilenameUtils;
import org.nd4j.resources.Downloader;

import java.io.File;
import java.net.URL;

public enum DownloaderUtility {

    /** Skymind datavec resources stored under AZURE_BLOB_URL/datavec-examples */
    BASICDATAVECEXAMPLE("BasicDataVecExample.zip", "datavec-examples", "92f87e0ceb81093ff8b49e2b4e0a5a02", "1KB"),
    INPUTSPLIT("inputsplit.zip", "datavec-examples", "f316b5274bab3b0f568eded9bee1c67f", "128KB"),
    IRISDATA("IrisData.zip", "datavec-examples", "bb49e38bb91089634d7ef37ad8e430b8", "1KB"),
    JOINEXAMPLE("JoinExample.zip", "datavec-examples", "cbd6232cf1463d68ff24807d5dd8b530", "1KB"),
    /** Skymind dl4j-examples resources stored under AZURE_BLOB_URL/dl4j-examples */
    ANIMALS("animals.zip", "dl4j-examples", "1976a1f2b61191d2906e4f615246d63e", "820KB"),
    ANOMALYSEQUENCEDATA("anomalysequencedata.zip", "dl4j-examples", "51bb7c50e265edec3a241a2d7cce0e73", "3MB"),
    CAPTCHAIMAGE("captchaImage.zip", "dl4j-examples", "1d159c9587fdbb1cbfd66f0d62380e61", "42MB"),
    CLASSIFICATIONDATA("classification.zip", "dl4j-examples", "dba31e5838fe15993579edbf1c60c355", "77KB"),
    DATAEXAMPLES("DataExamples.zip", "dl4j-examples", "e4de9c6f19aaae21fed45bfe2a730cbb", "2MB"),
    LOTTERYDATA("lottery.zip", "dl4j-examples", "1e54ac1210e39c948aa55417efee193a", "2MB"),
    MODELIMPORT("modelimport.zip", "dl4j-examples", "411df05aace1c9ff587e430a662ce621", "3MB"),
    NEWSDATA("NewsData.zip", "dl4j-examples", "0d08e902faabe6b8bfe5ecdd78af9f64", "21MB"),
    NLPDATA("nlp.zip", "dl4j-examples", "1ac7cd7ca08f13402f0e3b83e20c0512", "91MB"),
    PREDICTGENDERDATA("PredictGender.zip", "dl4j-examples", "42a3fec42afa798217e0b8687667257e", "3MB"),
    STYLETRANSFER("styletransfer.zip", "dl4j-examples", "b2b90834d667679d7ee3dfb1f40abe94", "3MB"),
    //This download is handled a little differently since the zip is not a single directory but a bunch of stuff at the top level
    BERTEXAMPLE("https://dl4jdata.blob.core.windows.net/testresources", "bert_mrpc_frozen_v1.zip", "bert-frozen-example", "7cef8bbe62e701212472f77a0361f443", "420MB"),
    /** Skymind tf-import-examples resources stored under AZURE_BLOB_URL/tf-import-examples */
    TFIMPORTEXAMPLES("resources.zip", "tf-import-examples", "4895e40e71b17799e4d6fb75d5a22491", "3MB"),
    /** Skymind dl4j-spark example resources stored under AZURE_BLOB_URL/dl4j-spark-examples */
    PATENTEXAMPLE("patentExample.zip", "dl4j-spark-examples", "435e2b814d866550678d2ac4d8cc5423", "10KB");

    private final String BASE_URL;
    private final String DATA_FOLDER;
    private final String ZIP_FILE;
    private final String MD5;
    private final String DATA_SIZE;
    private static final String AZURE_BLOB_URL = "https://dl4jdata.blob.core.windows.net/dl4j-examples";

    /**
     * For use with resources uploaded to Azure blob storage.
     *
     * @param zipFile    Name of zipfile. Should be a zip of a single directory with the same name
     * @param dataFolder The folder to extract to under ~/dl4j-examples-data
     * @param md5        of zipfile
     * @param dataSize   of zipfile
     */
    DownloaderUtility(String zipFile, String dataFolder, String md5, String dataSize) {
        this(AZURE_BLOB_URL + "/" + dataFolder, zipFile, dataFolder, md5, dataSize);
    }

    /**
     * Downloads a zip file from a base url to a specified directory under the user's home directory
     *
     * @param baseURL    URL of file
     * @param zipFile    Name of zipfile to download from baseURL i.e baseURL+"/"+zipFile gives full URL
     * @param dataFolder The folder to extract to under ~/dl4j-examples-data
     * @param md5        of zipfile
     * @param dataSize   of zipfile
     */
    DownloaderUtility(String baseURL, String zipFile, String dataFolder, String md5, String dataSize) {
        BASE_URL = baseURL;
        DATA_FOLDER = dataFolder;
        ZIP_FILE = zipFile;
        MD5 = md5;
        DATA_SIZE = dataSize;
    }

    public String Download() throws Exception {
        return Download(true);
    }

    public String Download(boolean returnSubFolder) throws Exception {
        String dataURL = BASE_URL + "/" + ZIP_FILE;
        String downloadPath = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), ZIP_FILE);
        String extractDir = FilenameUtils.concat(System.getProperty("user.home"), "dl4j-examples-data/" + DATA_FOLDER);
        if (!new File(extractDir).exists())
            new File(extractDir).mkdirs();
        String dataPathLocal = extractDir;
        if (returnSubFolder) {
            String resourceName = ZIP_FILE.substring(0, ZIP_FILE.lastIndexOf(".zip"));
            dataPathLocal = FilenameUtils.concat(extractDir, resourceName);
        }
        int downloadRetries = 10;
        if (!new File(dataPathLocal).exists() || new File(dataPathLocal).list().length == 0) {
            System.out.println("_______________________________________________________________________");
            System.out.println("Downloading data (" + DATA_SIZE + ") and extracting to " + dataPathLocal);
            System.out.println("_______________________________________________________________________");
            Downloader.downloadAndExtract("files",
                    new URL(dataURL),
                    new File(downloadPath),
                    new File(extractDir),
                    MD5,
                    downloadRetries);
        } else {
            System.out.println("_______________________________________________________________________");
            System.out.println("Example data present in " + dataPathLocal);
            System.out.println("_______________________________________________________________________");
        }
        return dataPathLocal;
    }
}
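`Download()` avoids re-downloading data that is already on disk: it resolves an extraction directory `~/dl4j-examples-data/<dataFolder>`, optionally appends the zip name without its `.zip` suffix, and only calls `Downloader.downloadAndExtract` when that directory is missing or empty. The sketch below reproduces just that path logic so it can be checked without network access. The class `DataPathSketch` and its methods are hypothetical, and it uses `java.nio.file` instead of commons-io `FilenameUtils`; unlike the original, it treats a plain file at the data path as "missing" instead of hitting a `NullPointerException` from `File.list()`.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.stream.Stream;

class DataPathSketch {

    // Mirrors the extractDir / dataPathLocal computation in DownloaderUtility.Download(boolean)
    static Path localDataPath(String userHome, String dataFolder, String zipFile, boolean returnSubFolder) {
        Path extractDir = Paths.get(userHome, "dl4j-examples-data", dataFolder);
        if (!returnSubFolder) {
            return extractDir;
        }
        String resourceName = zipFile.substring(0, zipFile.lastIndexOf(".zip"));
        return extractDir.resolve(resourceName);
    }

    // Download when the folder is missing or has no entries (a non-directory counts as missing)
    static boolean needsDownload(Path dataPathLocal) throws IOException {
        if (!Files.isDirectory(dataPathLocal)) {
            return true;
        }
        try (Stream<Path> entries = Files.list(dataPathLocal)) {
            return !entries.findAny().isPresent();
        }
    }

    public static void main(String[] args) throws IOException {
        Path p = localDataPath(System.getProperty("user.home"), "dl4j-examples", "classification.zip", true);
        System.out.println(p + " needs download: " + needsDownload(p));
    }
}
```

For `CLASSIFICATIONDATA` this resolves to `~/dl4j-examples-data/dl4j-examples/classification`, which is the directory `MLPClassifierMoon` reads `moon_data_train.csv` and `moon_data_eval.csv` from.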
Create com/weilian/springbootdl4j/service/MLPClassifierMoon.java with the following code:
package com.weilian.springbootdl4j.service;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

import java.io.File;

/**
 * "Moon" Data Classification Example
 *
 * Based on the data from Jason Baldridge:
 * https://github.com/jasonbaldridge/try-tf/tree/master/simdata
 *
 * @author Josh Patterson
 * @author Alex Black (added plots)
 *
 */
@SuppressWarnings("DuplicatedCode")
public class MLPClassifierMoon {

    public static String dataLocalPath;

    public String run() throws Exception {
        int seed = 123;
        double learningRate = 0.005;
        int batchSize = 50;
        int nEpochs = 100;

        int numInputs = 2;
        int numOutputs = 2;
        int numHiddenNodes = 50;

        dataLocalPath = DownloaderUtility.CLASSIFICATIONDATA.Download();
        //Load the training data:
        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(new File(dataLocalPath, "moon_data_train.csv")));
        DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, 2);

        //Load the test/evaluation data:
        RecordReader rrTest = new CSVRecordReader();
        rrTest.initialize(new FileSplit(new File(dataLocalPath, "moon_data_eval.csv")));
        DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2);

        //log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(learningRate, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .activation(Activation.RELU)
                        .build())
                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.SOFTMAX)
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(100)); //Print score every 100 parameter updates

        model.fit(trainIter, nEpochs);

        System.out.println("Evaluate model....");
        Evaluation eval = model.evaluate(testIter);

        //Print the evaluation statistics
        System.out.println(eval.stats());

        //------------------------------------------------------------------------------------
        //Training is complete. Code that follows is for plotting the data & predictions only

        //Plot the data
        double xMin = -1.5;
        double xMax = 2.5;
        double yMin = -1;
        double yMax = 1.5;

        //Let's evaluate the predictions at every point in the x/y input space, and plot this in the background
        int nPointsPerAxis = 100;
        double[][] evalPoints = new double[nPointsPerAxis * nPointsPerAxis][2];
        int count = 0;
        for( int i=0; i
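The configuration above describes a 2 → 50 → 2 network: a dense ReLU layer followed by a softmax output layer trained with negative log-likelihood. To show what that network computes for a single (x, y) point, and how a 100 × 100 grid over [xMin, xMax] × [yMin, yMax] could be laid out for the background plot, here is a plain-Java sketch. It does not use DL4J, the weights are untrained random placeholders, and the class `MoonMlpSketch` is hypothetical. The grid spacing (evenly spaced points including both endpoints) is an assumption, since the original loop body is cut off above.

```java
import java.util.Random;

class MoonMlpSketch {

    // Evenly spaced nPointsPerAxis x nPointsPerAxis grid over [xMin, xMax] x [yMin, yMax], endpoints included
    static double[][] grid(double xMin, double xMax, double yMin, double yMax, int nPointsPerAxis) {
        double[][] points = new double[nPointsPerAxis * nPointsPerAxis][2];
        int count = 0;
        for (int i = 0; i < nPointsPerAxis; i++) {
            for (int j = 0; j < nPointsPerAxis; j++) {
                points[count][0] = xMin + i * (xMax - xMin) / (nPointsPerAxis - 1);
                points[count][1] = yMin + j * (yMax - yMin) / (nPointsPerAxis - 1);
                count++;
            }
        }
        return points;
    }

    // Forward pass: hidden = relu(x W1 + b1), probs = softmax(hidden W2 + b2); weights are [fanIn][fanOut]
    static double[] forward(double[] x, double[][] w1, double[] b1, double[][] w2, double[] b2) {
        double[] hidden = new double[b1.length];
        for (int h = 0; h < hidden.length; h++) {
            double z = b1[h];
            for (int k = 0; k < x.length; k++) {
                z += x[k] * w1[k][h];
            }
            hidden[h] = Math.max(0.0, z); // ReLU
        }
        double[] probs = new double[b2.length];
        double maxLogit = Double.NEGATIVE_INFINITY;
        for (int o = 0; o < probs.length; o++) {
            double z = b2[o];
            for (int h = 0; h < hidden.length; h++) {
                z += hidden[h] * w2[h][o];
            }
            probs[o] = z; // holds the logit until the softmax below
            maxLogit = Math.max(maxLogit, z);
        }
        // Softmax, shifted by the largest logit for numerical stability
        double sum = 0.0;
        for (int o = 0; o < probs.length; o++) {
            probs[o] = Math.exp(probs[o] - maxLogit);
            sum += probs[o];
        }
        for (int o = 0; o < probs.length; o++) {
            probs[o] /= sum;
        }
        return probs;
    }

    // Placeholder weights: Gaussian with variance 2 / (fanIn + fanOut), intended to mimic Xavier init
    static double[][] randomWeights(int fanIn, int fanOut, Random rng) {
        double std = Math.sqrt(2.0 / (fanIn + fanOut));
        double[][] w = new double[fanIn][fanOut];
        for (int r = 0; r < fanIn; r++) {
            for (int c = 0; c < fanOut; c++) {
                w[r][c] = rng.nextGaussian() * std;
            }
        }
        return w;
    }

    // Untrained 2 -> 50 -> 2 network evaluated on the plotting grid used in MLPClassifierMoon
    static double[][] predictGrid(long seed) {
        Random rng = new Random(seed);
        double[][] w1 = randomWeights(2, 50, rng);
        double[] b1 = new double[50]; // biases start at zero
        double[][] w2 = randomWeights(50, 2, rng);
        double[] b2 = new double[2];
        double[][] points = grid(-1.5, 2.5, -1, 1.5, 100);
        double[][] predictions = new double[points.length][];
        for (int p = 0; p < points.length; p++) {
            predictions[p] = forward(points[p], w1, b1, w2, b2);
        }
        return predictions;
    }

    public static void main(String[] args) {
        double[][] preds = predictGrid(123);
        System.out.println(preds.length + " points, P(class 0) at first point = " + preds[0][0]);
    }
}
```

Each row of the result is a probability distribution over the two moon classes; after real training, coloring the grid by the larger probability would show the learned decision boundary.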
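Running in IDEA
Once the code is in place, run the SpringbootDl4jApplication class in IDEA. The application listens on local port 8080 (unless you have configured a different port).
Open http://localhost:8080 in a local browser.
IDEA now starts training the network, and the training output is printed to the console.
When training finishes, the browser displays "train finished".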
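With that, we have a simple Spring Boot endpoint that works as a network-training API.
Of course, combining Spring Boot with machine-learning projects makes many more things possible. I'll keep sharing related material, so feel free to follow me!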