【DJL】Springboot+Maven+DJL实现java调取pytorch模型

一、前言

近期学习了DJL(深度java学习),有了一点小的研究成果,特以此博客分享给大家。这个技术是一个特别新的技术,是亚马逊云服务在2019年re:Invent大会推出的专为Java开发者量身定制的深度学习框架,网上的资料比较少,只有官方文档可以参考,研究起来难度比较大,但是经过不懈的努力,终于搞定了,接下来以官网的demo入门。由于这块有很多坑,所以有必要好好的说一下。

官网地址:https://docs.djl.ai/jupyter/load_pytorch_model.html

二、demo
1、创建SpringBoot项目,导入pom依赖
    <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.6.0</version>
        </dependency>

        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.6.0</version>
            <scope>runtime</scope>
        </dependency>


        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.5.0</version>
            <scope>runtime</scope>
        </dependency>

        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>1.7.26</version>
        </dependency>

        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-simple</artifactId>
            <version>1.7.26</version>
        </dependency>

        <dependency>
            <groupId>net.java.dev.jna</groupId>
            <artifactId>jna</artifactId>
            <version>5.3.0</version>
        </dependency>

        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu</artifactId>
            <classifier>win-x86_64</classifier>
            <scope>runtime</scope>
            <version>1.5.0</version>
        </dependency>
 </dependencies>
2、下载模型
    DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());
       

下载完成后生成一个build文件夹,里面有你的模型(如下下载失败,请翻墙,连接外网)
在这里插入图片描述

3、创建一个Translator
 Pipeline pipeline = new Pipeline();
        pipeline.add(new Resize(256))
                .add(new CenterCrop(224, 224))
                .add(new ToTensor())
                .add(new Normalize(
                        new float[]{0.485f, 0.456f, 0.406f},
                        new float[]{0.229f, 0.224f, 0.225f}));

        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .setPipeline(pipeline)
                .optApplySoftmax(true)
                .build();
4、加载你的模型resnet18
   System.setProperty("ai.djl.repository.zoo.location", "build/pytorch_models/resnet18");
    Criteria<Image, Classifications> criteria = Criteria.builder()
            .setTypes(Image.class, Classifications.class)
            // only search the model in local directory
            // "ai.djl.localmodelzoo:{name of the model}"
            .optArtifactId("ai.djl.localmodelzoo:resnet18")
            .optTranslator(translator)
            .optProgress(new ProgressBar()).build();
            ZooModel model = ModelZoo.loadModel(criteria);
5、使用图片进行预测
		// 自己本地
	    File fs=new File("D:\\testdjl\\dog.jpg");
        Image img = ImageFactory.getInstance().fromInputStream(new FileInputStream(fs));
        Predictor<Image, Classifications> predictor = model.newPredictor();
        Classifications classifications = predictor.predict(img);
        System.out.println(classifications);
6、执行结果

在这里插入图片描述

三、在运行的时候可能会报如下的错
1、No deep learning engine found

官网给出地址如下:
https://github.com/awslabs/djl/blob/master/docs/development/troubleshooting.md

我是 通过解决了nsatisfiedLinkError问题,解决的No deep learning engine found错误,官网有一个提示:CN:如果您在中国,可以使用DirectX修复工具来安装遗失依赖项。所以我就试着通过DirectX这个修复工具进行修复。
下载地址:https://www.onlinedown.net/soft/120082.htm,下载完之后安装就行了,如下:
在这里插入图片描述
2、路径中有中文(open file faild)在这里插入图片描述
将中文改成英文就好了

3、下载模型失败,记得翻墙

后记:

如果对你有所帮助,请记得点赞。

  • 20
    点赞
  • 103
    收藏
    觉得还不错? 一键收藏
  • 26
    评论
Java Spring Boot是一个开源的Java Web框架,它可以帮助开发者快速构建基于Spring的应用程序。而PyTorch是一个基于Python的科学计算库,它主要用于机器学习和深度学习领域。DJL是一个基于Java的深度学习框架,它可以与PyTorch模型进行集成。在Java Spring Boot中使用DJL可以方便地调用Python训练的PyTorch模型实现机器学习和深度学习的功能。 以下是Java Spring Boot使用DJL部署Python训练的PyTorch模型的步骤: 1. 在pom.xml文件中添加DJL的依赖: ```xml <dependency> <groupId>ai.djl.pytorch</groupId> <artifactId>pytorch-engine</artifactId> <version>0.11.0</version> </dependency> ``` 2. 创建一个Translator,用于将输入数据转换为NDArray(Tensor)类型: ```java public class MyTranslator implements Translator<String, NDArray> { @Override public NDArray processInput(TranslatorContext ctx, String input) throws Exception { // 将输入数据转换为NDArray类型 float[] data = new float[input.length()]; for (int i = 0; i < input.length(); i++) { data[i] = (float) (input.charAt(i) - '0'); } return NDArray.create(data, new Shape(1, input.length())); } @Override public String processOutput(TranslatorContext ctx, NDArray output) throws Exception { // 将输出数据转换为String类型 return String.valueOf(output.argMax().getInt()); } } ``` 3. 创建一个Predictor,用于加载PyTorch模型并进行预测: ```java public class MyPredictor { private final Predictor<String, NDArray> predictor; public MyPredictor() throws IOException, ModelException { // 加载PyTorch模型 Criteria<NDArray, String> criteria = Criteria.builder() .setTypes(NDArray.class, String.class) .optModelUrls("file:///path/to/model.pt") .optTranslator(new MyTranslator()) .build(); Model model = Model.newInstance(); model.setBlock(new Mlp(784, 10, new int[]{128, 64})); model.load(criteria); // 创建Predictor predictor = model.newPredictor(new MyTranslator()); } public String predict(String input) { // 进行预测 return predictor.predict(input); } } ``` 4. 在Controller中调用Predictor进行预测: ```java @RestController public class MyController { private final MyPredictor predictor; public MyController() throws IOException, ModelException { predictor = new MyPredictor(); } @GetMapping("/predict") public String predict(@RequestParam String input) { // 调用Predictor进行预测 return predictor.predict(input); } } ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 26
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值