DJL Java环境下部署pytorch模型推理

由于大数据基本都是Java环境,希望与深度学习结合的话,需要将深度学习模型部署在Java环境下。传统方式使用flask搭建接口,在Java环境中对其调用,但通信时间和内存问题限制了这种方式的发展。

DJL是采用Java编写的深度学习框架,支持MXnet,Tensorflow,Pytorch引擎,这意味着同一个模型采用不同语言编写,在DJL框架中运行只需要更改依赖,代码完全一样即可执行。关于DJL更多的介绍大家可以浏览DJL官网,知乎,以及b站的课程。

知乎专栏:DJL深度学习库 - 知乎

b站课程录播:深度学习兽的个人空间_哔哩哔哩_Bilibili 

GitHub:DeepJavaLibrary · GitHub 

下面介绍部署pytorch模型步骤以及我个人遇到的一些坑,希望对大家有所帮助

首先是pom文件依赖

import torch
print(torch.__version__)

 首先使用该命令查看本地环境下的pytorch版本,根据本地的pytorch版本,选取合适的engine

PyTorch Engine - Deep Java Library

这是DJL官网的例子,也包含Linux和maxOS下的依赖配置,我的pytorch版本是1.9.0,给出我的pom文件做参考

 <dependencies>
        
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.18.0</version>
        </dependency>

        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.18.0</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.9.1</version>
            <scope>runtime</scope>
        </dependency>

        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu</artifactId>
            <classifier>win-x86_64</classifier>
            <scope>runtime</scope>
            <version>1.11.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-jni</artifactId>
            <version>1.11.0-0.18.0</version>
            <scope>runtime</scope>
        </dependency>

    </dependencies>

 对于加载自己的本地模型,踩到的两个坑,第一个就是如果该模型是用GPU训练的,那么之后推理也需要使用GPU,如果想用CPU推理,那就需要用CPU训练网络(这一条我不确定是否正确,只是我这样修改后确实没有报错了)第二个坑就是在python中保存模型时,要使用下面的代码

net.eval()

input = np.random.uniform(0, 1, (1,1, 2048, 1))
input = input.astype(np.float32)
input = torch.from_numpy(input)
script = torch.jit.trace(net, input)
script.save(save_path+"/"+"0726+"+str(test_acc)+".pt")

使用script.save保存模型,之前我的代码是torch.save,保存的模型在DJL中加载会报错

DJL加载model首先获取本地模型的url

Path modeldir = Paths.get("D:\\1.pt");

之后重写Translator,这个需要自定义模型的输入输出类型

Translator<NDList, Long> translator = new NoBatchifyTranslator<NDList, Long>() {
            @Override
            public NDList processInput(TranslatorContext translatorContext, NDList inputs) throws Exception {
                NDManager ndManager = translatorContext.getNDManager();
                NDArray ndArray = ndManager.create(new float[2048]).reshape(1,1,2048,1);
                //ndArray作为输入
                System.out.println(ndArray);
                return new NDList(ndArray);
            }
            @Override
            public Long processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
                System.out.println("process: " + ndList);
                System.out.println("process-1:" + ndList.get(0));
                System.out.println("process-2:" + ndList.get(0).argMax());
                NDArray tmp = ndList.get(0).argMax();
                Long label =  tmp.max().getLong();
                return  label;
            }

        };

这个输入还是有问题的,传入的NDList完全没用上,一直在定义新的ndArray

translator完成后,调用Criteria,加载模型

        Criteria<NDList, Long> criteria = Criteria.builder()
                .setTypes(NDList.class,Long.class)
                .optModelPath(modeldir)
                .optTranslator(translator)
                .build();

之后调用predictor,生成预测器 

Predictor<NDList, Long> predictor = criteria.loadModel().newPredictor();

创建样本,测试样本输出(由于translator的问题,这里传什么进去结果都一样)

NDManager manager = NDManager.newBaseManager();

NDArray array = manager.randomUniform(0, 1, new Shape(2048));

NDList testarray = new NDList(array);

Long result = predictor.predict(testarray);

System.out.println("result:" + result);

 

 

 

 

 

  • 1
    点赞
  • 37
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值