It's Time to Expose Your Machine Learning Project as an API with Spring Boot

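Introduction

I have recently been working on my graduation project and doing some machine-learning research.

One thought keeps coming back to me: there is still a gap between research and real applications. Suppose we have built a network that performs some kind of classification. Turning it into an application that others can use is often difficult, and usually only specialists can understand and reuse it. More precisely, it is not packaged well enough.

So I have been thinking about how to expose these capabilities, which led to this post.

This post uses the simplest possible example: run a Spring Boot project locally in IDEA, open http://localhost:8080 in a browser to trigger training of a network, and get back a "training finished" message once training ends. The network is MLPClassifierMoon, a classification example from the deeplearning4j project.

Project repository: https://github.com/xiaozhch5/spring-guides/tree/master/springboot-dl4j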

Creating the project

You can download the project directly from my GitHub repository, or follow the steps below.

First, create a Spring Boot project in IDEA.

[Screenshots: creating the Spring Boot project in IDEA]

POM file

The pom.xml file is as follows:

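<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.2.5.RELEASE</version>
        <relativePath/>
    </parent>
    <groupId>com.weilian</groupId>
    <artifactId>springboot-dl4j</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>springboot-dl4j</name>
    <description>Demo project for Spring Boot</description>

    <properties>
        <java.version>1.8</java.version>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <!-- ND4J CPU backend -->
        <nd4j.backend>nd4j-native-platform</nd4j.backend>
        <nd4j.version>1.0.0-beta6</nd4j.version>
        <dl4j.version>1.0.0-beta6</dl4j.version>
        <datavec.version>1.0.0-beta6</datavec.version>
        <arbiter.version>1.0.0-beta6</arbiter.version>
        <hadoop.version>2.2.0</hadoop.version>
        <logback.version>1.1.7</logback.version>
        <jfreechart.version>1.0.13</jfreechart.version>
        <jcommon.version>1.0.23</jcommon.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
            <exclusions>
                <exclusion>
                    <groupId>org.junit.vintage</groupId>
                    <artifactId>junit-vintage-engine</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${nd4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nlp</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-ui</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-parallel-wrapper</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-hadoop</artifactId>
            <version>${datavec.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-common</artifactId>
            <version>${hadoop.version}</version>
            <exclusions>
                <exclusion><groupId>jdk.tools</groupId><artifactId>jdk.tools</artifactId></exclusion>
                <exclusion><groupId>log4j</groupId><artifactId>log4j</artifactId></exclusion>
                <exclusion><groupId>org.slf4j</groupId><artifactId>slf4j-log4j12</artifactId></exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>arbiter-deeplearning4j</artifactId>
            <version>${arbiter.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>arbiter-ui</artifactId>
            <version>${arbiter.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-codec</artifactId>
            <version>${datavec.version}</version>
        </dependency>
        <dependency>
            <groupId>jfree</groupId>
            <artifactId>jfreechart</artifactId>
            <version>${jfreechart.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jfree</groupId>
            <artifactId>jcommon</artifactId>
            <version>${jcommon.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.httpcomponents</groupId>
            <artifactId>httpclient</artifactId>
            <version>4.3.6</version>
        </dependency>
        <!-- Needs an explicit <version> to build, and no class in this project imports it:
        <dependency>
            <groupId>org.deeplearning4j.examples</groupId>
            <artifactId>shared-utilities</artifactId>
        </dependency>
        -->
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>
</project>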

Creating the controller

Create a controller class, com/weilian/springbootdl4j/controller/Dl4jController.java, with the following code:

package com.weilian.springbootdl4j.controller;

import com.weilian.springbootdl4j.service.MLPClassifierMoon;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
public class Dl4jController {

    // GET / trains the network inside the request and returns run()'s result once training is done
    @GetMapping("/")
    public String dl4jResult() throws Exception {
        return new MLPClassifierMoon().run();
    }
}

Creating the service classes

Create com/weilian/springbootdl4j/service/DownloaderUtility.java with the following code:

package com.weilian.springbootdl4j.service;

import org.apache.commons.io.FilenameUtils;
import org.nd4j.resources.Downloader;

import java.io.File;
import java.net.URL;

public enum DownloaderUtility {

    /**
     * Skymind datavec resources stored under AZURE_BLOB_URL/datavec-examples
     */
    BASICDATAVECEXAMPLE("BasicDataVecExample.zip", "datavec-examples", "92f87e0ceb81093ff8b49e2b4e0a5a02", "1KB"),
    INPUTSPLIT("inputsplit.zip", "datavec-examples", "f316b5274bab3b0f568eded9bee1c67f", "128KB"),
    IRISDATA("IrisData.zip", "datavec-examples", "bb49e38bb91089634d7ef37ad8e430b8", "1KB"),
    JOINEXAMPLE("JoinExample.zip", "datavec-examples", "cbd6232cf1463d68ff24807d5dd8b530", "1KB"),
    /**
     * Skymind dl4j-examples resources stored under AZURE_BLOB_URL/dl4j-examples
     */
    ANIMALS("animals.zip", "dl4j-examples", "1976a1f2b61191d2906e4f615246d63e", "820KB"),
    ANOMALYSEQUENCEDATA("anomalysequencedata.zip", "dl4j-examples", "51bb7c50e265edec3a241a2d7cce0e73", "3MB"),
    CAPTCHAIMAGE("captchaImage.zip", "dl4j-examples", "1d159c9587fdbb1cbfd66f0d62380e61", "42MB"),
    CLASSIFICATIONDATA("classification.zip", "dl4j-examples", "dba31e5838fe15993579edbf1c60c355", "77KB"),
    DATAEXAMPLES("DataExamples.zip", "dl4j-examples", "e4de9c6f19aaae21fed45bfe2a730cbb", "2MB"),
    LOTTERYDATA("lottery.zip", "dl4j-examples", "1e54ac1210e39c948aa55417efee193a", "2MB"),
    MODELIMPORT("modelimport.zip", "dl4j-examples", "411df05aace1c9ff587e430a662ce621", "3MB"),
    NEWSDATA("NewsData.zip", "dl4j-examples", "0d08e902faabe6b8bfe5ecdd78af9f64", "21MB"),
    NLPDATA("nlp.zip", "dl4j-examples", "1ac7cd7ca08f13402f0e3b83e20c0512", "91MB"),
    PREDICTGENDERDATA("PredictGender.zip", "dl4j-examples", "42a3fec42afa798217e0b8687667257e", "3MB"),
    STYLETRANSFER("styletransfer.zip", "dl4j-examples", "b2b90834d667679d7ee3dfb1f40abe94", "3MB"),
    //This download is handled a little differently since the zip is not a single directory but a bunch of stuff at the top level
    BERTEXAMPLE("https://dl4jdata.blob.core.windows.net/testresources", "bert_mrpc_frozen_v1.zip", "bert-frozen-example", "7cef8bbe62e701212472f77a0361f443", "420MB"),
    /**
     * Skymind tf-import-examples resources stored under AZURE_BLOB_URL/tf-import-examples
     */
    TFIMPORTEXAMPLES("resources.zip", "tf-import-examples", "4895e40e71b17799e4d6fb75d5a22491", "3MB"),
    /**
     * Skymind dl4j-spark example resources stored under AZURE_BLOB_URL/dl4j-spark-examples
     */
    PATENTEXAMPLE("patentExample.zip", "dl4j-spark-examples", "435e2b814d866550678d2ac4d8cc5423", "10KB");

    private final String BASE_URL;
    private final String DATA_FOLDER;
    private final String ZIP_FILE;
    private final String MD5;
    private final String DATA_SIZE;
    private static final String AZURE_BLOB_URL = "https://dl4jdata.blob.core.windows.net/dl4j-examples";

    /**
     * For use with resources uploaded to Azure blob storage.
     *
     * @param zipFile    Name of zipfile. Should be a zip of a single directory with the same name
     * @param dataFolder The folder to extract to under ~/dl4j-examples-data
     * @param md5        of zipfile
     * @param dataSize   of zipfile
     */
    DownloaderUtility(String zipFile, String dataFolder, String md5, String dataSize) {
        this(AZURE_BLOB_URL + "/" + dataFolder, zipFile, dataFolder, md5, dataSize);
    }

    /**
     * Downloads a zip file from a base url to a specified directory under the user's home directory
     *
     * @param baseURL    URL of file
     * @param zipFile    Name of zipfile to download from baseURL i.e baseURL+"/"+zipFile gives full URL
     * @param dataFolder The folder to extract to under ~/dl4j-examples-data
     * @param md5        of zipfile
     * @param dataSize   of zipfile
     */
    DownloaderUtility(String baseURL, String zipFile, String dataFolder, String md5, String dataSize) {
        BASE_URL = baseURL;
        DATA_FOLDER = dataFolder;
        ZIP_FILE = zipFile;
        MD5 = md5;
        DATA_SIZE = dataSize;
    }

    public String Download() throws Exception {
        return Download(true);
    }

    public String Download(boolean returnSubFolder) throws Exception {
        String dataURL = BASE_URL + "/" + ZIP_FILE;
        String downloadPath = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), ZIP_FILE);
        String extractDir = FilenameUtils.concat(System.getProperty("user.home"), "dl4j-examples-data/" + DATA_FOLDER);
        if (!new File(extractDir).exists())
            new File(extractDir).mkdirs();
        String dataPathLocal = extractDir;
        if (returnSubFolder) {
            String resourceName = ZIP_FILE.substring(0, ZIP_FILE.lastIndexOf(".zip"));
            dataPathLocal = FilenameUtils.concat(extractDir, resourceName);
        }
        int downloadRetries = 10;
        if (!new File(dataPathLocal).exists() || new File(dataPathLocal).list().length == 0) {
            System.out.println("_______________________________________________________________________");
            System.out.println("Downloading data (" + DATA_SIZE + ") and extracting to " + dataPathLocal);
            System.out.println("_______________________________________________________________________");
            Downloader.downloadAndExtract("files",
                new URL(dataURL),
                new File(downloadPath),
                new File(extractDir),
                MD5,
                downloadRetries);
        } else {
            System.out.println("_______________________________________________________________________");
            System.out.println("Example data present in " + dataPathLocal);
            System.out.println("_______________________________________________________________________");
        }
        return dataPathLocal;
    }
}
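The download logic is easiest to follow by looking at where the data ends up. Below is a hypothetical JDK-only helper, DataPathSketch (not part of DL4J or of this project), that mirrors two pieces of Download(boolean): the local path computation and the "download only if the folder is missing or empty" check. Because CLASSIFICATIONDATA is declared with dataFolder "dl4j-examples" and zip "classification.zip", MLPClassifierMoon ends up reading its CSV files from ~/dl4j-examples-data/dl4j-examples/classification.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical helper, not part of DL4J or the original project: it re-implements only the
// path and "is the data already there?" logic of DownloaderUtility.Download(boolean).
class DataPathSketch {

    // extractDir = <userHome>/dl4j-examples-data/<dataFolder>; with returnSubFolder the zip name
    // minus ".zip" is appended. Assumes zipFile ends with ".zip", as every enum constant does.
    static Path resolve(String userHome, String dataFolder, String zipFile, boolean returnSubFolder) {
        Path extractDir = Paths.get(userHome, "dl4j-examples-data", dataFolder);
        if (!returnSubFolder) {
            return extractDir;
        }
        String resourceName = zipFile.substring(0, zipFile.lastIndexOf(".zip"));
        return extractDir.resolve(resourceName);
    }

    // A download is needed when the data folder is missing or empty.
    // (A regular file at that path is treated as missing here.)
    static boolean needsDownload(Path dataPath) {
        if (!Files.isDirectory(dataPath)) {
            return true;
        }
        try (DirectoryStream<Path> entries = Files.newDirectoryStream(dataPath)) {
            return !entries.iterator().hasNext();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        // Same arguments as the CLASSIFICATIONDATA constant
        Path moonData = resolve(System.getProperty("user.home"), "dl4j-examples", "classification.zip", true);
        System.out.println(moonData + " -> needs download: " + needsDownload(moonData));
    }
}
```

Running it before and after the first call to the API shows the switch from "needs download" to reusing the cached copy.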

Create com/weilian/springbootdl4j/service/MLPClassifierMoon.java with the following code:

package com.weilian.springbootdl4j.service;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

import java.io.File;

/**
 * "Moon" Data Classification Example
 *
 * Based on the data from Jason Baldridge:
 * https://github.com/jasonbaldridge/try-tf/tree/master/simdata
 *
 * @author Josh Patterson
 * @author Alex Black (added plots)
 */
@SuppressWarnings("DuplicatedCode")
public class MLPClassifierMoon {

    public static String dataLocalPath;

    public String run() throws Exception {
        int seed = 123;
        double learningRate = 0.005;
        int batchSize = 50;
        int nEpochs = 100;

        int numInputs = 2;
        int numOutputs = 2;
        int numHiddenNodes = 50;

        // Downloads classification.zip on first use, then reuses the extracted local copy
        dataLocalPath = DownloaderUtility.CLASSIFICATIONDATA.Download();

        //Load the training data (label in column 0, 2 possible classes):
        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(new File(dataLocalPath, "moon_data_train.csv")));
        DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, 2);

        //Load the test/evaluation data:
        RecordReader rrTest = new CSVRecordReader();
        rrTest.initialize(new FileSplit(new File(dataLocalPath, "moon_data_eval.csv")));
        DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2);

        //log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(learningRate, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .activation(Activation.RELU)
                        .build())
                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.SOFTMAX)
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(100));    //Print score every 100 parameter updates

        model.fit(trainIter, nEpochs);

        System.out.println("Evaluate model....");
        Evaluation eval = model.evaluate(testIter);

        //Print the evaluation statistics
        System.out.println(eval.stats());

        //------------------------------------------------------------------------------------
        //Training is complete. Code that follows is for plotting the data & predictions only

        //Plot the data
        double xMin = -1.5;
        double xMax = 2.5;
        double yMin = -1;
        double yMax = 1.5;

        //Let's evaluate the predictions at every point in the x/y input space, and plot this in the background
        int nPointsPerAxis = 100;
        double[][] evalPoints = new double[nPointsPerAxis * nPointsPerAxis][2];
        int count = 0;
        // The rest of the plotting code is cut off in this copy of the article;
        // see the project repository for the full listing.

        // Returned to the browser once training has finished
        return "train finished";
    }
}
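To make the network configuration above more concrete, here is a plain-Java sketch, with no ND4J, of what the 2-50-2 network computes for a single input point: a dense layer with ReLU, a dense output layer with softmax, and the negative log-likelihood loss for the true class. MoonForwardSketch and its methods are hypothetical names. The weights are drawn with java.util.Random from a Gaussian with variance 2/(fanIn + fanOut), in the spirit of Xavier initialization; this is an assumption for illustration and will not reproduce DL4J's actual parameter values. No training happens here.

```java
import java.util.Random;

// Hypothetical sketch of the forward pass of the MLPClassifierMoon network (not DL4J code).
class MoonForwardSketch {

    // Fully connected layer: out[j] = b[j] + sum_i in[i] * w[i][j]
    static double[] dense(double[] in, double[][] w, double[] b) {
        double[] out = new double[b.length];
        for (int j = 0; j < b.length; j++) {
            double sum = b[j];
            for (int i = 0; i < in.length; i++) {
                sum += in[i] * w[i][j];
            }
            out[j] = sum;
        }
        return out;
    }

    static double[] relu(double[] x) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) out[i] = Math.max(0.0, x[i]);
        return out;
    }

    // Softmax with the max subtracted first so large inputs do not overflow exp()
    static double[] softmax(double[] x) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : x) max = Math.max(max, v);
        double[] out = new double[x.length];
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            out[i] = Math.exp(x[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < x.length; i++) out[i] /= sum;
        return out;
    }

    // Negative log-likelihood of the true class
    static double nll(double[] probs, int label) {
        return -Math.log(probs[label]);
    }

    // Gaussian weights with variance 2 / (fanIn + fanOut)
    static double[][] xavier(int fanIn, int fanOut, Random rng) {
        double std = Math.sqrt(2.0 / (fanIn + fanOut));
        double[][] w = new double[fanIn][fanOut];
        for (int i = 0; i < fanIn; i++)
            for (int j = 0; j < fanOut; j++)
                w[i][j] = rng.nextGaussian() * std;
        return w;
    }

    static double[] forward(double[] x, double[][] w1, double[] b1, double[][] w2, double[] b2) {
        double[] hidden = relu(dense(x, w1, b1)); // 2 -> 50, ReLU
        return softmax(dense(hidden, w2, b2));    // 50 -> 2, softmax
    }

    public static void main(String[] args) {
        Random rng = new Random(123); // same seed value as the example, but not ND4J's RNG
        int numInputs = 2, numHiddenNodes = 50, numOutputs = 2;
        double[][] w1 = xavier(numInputs, numHiddenNodes, rng);
        double[][] w2 = xavier(numHiddenNodes, numOutputs, rng);
        double[] b1 = new double[numHiddenNodes]; // zero biases
        double[] b2 = new double[numOutputs];
        double[] p = forward(new double[]{0.5, 0.25}, w1, b1, w2, b2);
        System.out.printf("P(class 0)=%.4f  P(class 1)=%.4f  NLL if label is 1: %.4f%n",
                p[0], p[1], nll(p, 1));
    }
}
```

In the real example, model.fit adjusts these weights and biases with the Nesterov-momentum updater to lower the average NLL over the training set; the sketch leaves that part to DL4J.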

Running in IDEA

Once the code is in place, run the SpringbootDl4jApplication class in IDEA. The application listens on local port 8080 (unless you have configured a different port).

[Screenshot: the application running in IDEA, listening on port 8080]

Open http://localhost:8080 in a local browser.

IDEA starts training the network, and the console shows the following output:

[Animation: training output in the IDEA console]

Because training runs inside the request, the browser keeps loading until training completes; it then displays train finished.

[Screenshot: the browser displaying train finished]
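A browser is not required: any HTTP client can call the endpoint. Below is a minimal JDK-only sketch using HttpURLConnection, so it also runs on the project's Java 1.8; TrainClient is a hypothetical name. Because the controller trains inside the request, the read timeout has to cover the whole training run; the 30-minute value in main is an arbitrary assumption.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;

// Hypothetical client sketch for the training endpoint (plain JDK, no Spring needed).
class TrainClient {

    // Sends a GET request and returns the response body decoded as UTF-8.
    static String get(String endpoint, int readTimeoutMs) throws IOException {
        HttpURLConnection conn = (HttpURLConnection) new URL(endpoint).openConnection();
        conn.setRequestMethod("GET");
        conn.setConnectTimeout(5_000);
        // The controller trains inside the request, so this must cover the whole training run.
        conn.setReadTimeout(readTimeoutMs);
        try {
            int status = conn.getResponseCode();
            if (status != HttpURLConnection.HTTP_OK) {
                throw new IOException("Unexpected HTTP status " + status);
            }
            try (InputStream in = conn.getInputStream();
                 ByteArrayOutputStream out = new ByteArrayOutputStream()) {
                byte[] buf = new byte[4096];
                int n;
                while ((n = in.read(buf)) != -1) {
                    out.write(buf, 0, n);
                }
                return new String(out.toByteArray(), StandardCharsets.UTF_8);
            }
        } finally {
            conn.disconnect();
        }
    }

    public static void main(String[] args) throws IOException {
        // Assumes the Spring Boot app is running locally on the default port 8080.
        int thirtyMinutesMs = 30 * 60 * 1000;
        System.out.println(get("http://localhost:8080/", thirtyMinutesMs));
    }
}
```

With the application running, this should print the same train finished text that the browser shows.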

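With that, we have a simple Spring Boot API that trains a neural network.

Of course, combining Spring Boot with machine-learning projects makes many more things possible. I will keep sharing more on this topic in later posts, so feel free to follow along.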
