Java使用pytorch模型进行数据推算

我的Java后台需要对数据进行分析,但找不到合适的方法,就准备用pytorch写个模型凑活着用。

使用的DJL调用pytorch引擎

Github:djl/README.md at master · deepjavalibrary/djl · GitHub

pom.xml中添加依赖:

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.16.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-auto</artifactId>
    <version>1.9.1</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-jni</artifactId>
    <version>1.9.1-0.16.0</version>
    <scope>runtime</scope>
</dependency>

注意version与pytorch版本有一个对应关系

PyTorch engine versionPyTorch native library version
pytorch-engine:0.15.0pytorch-native-auto: 1.8.1, 1.9.1, 1.10.0
pytorch-engine:0.14.0pytorch-native-auto: 1.8.1, 1.9.0, 1.9.1
pytorch-engine:0.13.0pytorch-native-auto:1.9.0
pytorch-engine:0.12.0pytorch-native-auto:1.8.1
pytorch-engine:0.11.0pytorch-native-auto:1.8.1
pytorch-engine:0.10.0pytorch-native-auto:1.7.1
pytorch-engine:0.9.0pytorch-native-auto:1.7.0
pytorch-engine:0.8.0pytorch-native-auto:1.6.0
pytorch-engine:0.7.0pytorch-native-auto:1.6.0
pytorch-engine:0.6.0pytorch-native-auto:1.5.0
pytorch-engine:0.5.0pytorch-native-auto:1.4.0
pytorch-engine:0.4.0pytorch-native-auto:1.4.0

其他问题访问连接:PyTorch Engine - Deep Java Library


官方给出了一个图片分类的例子,我只需要纯数据不需要图片输入。

随便写了个例子 输入是[a, b] 输出一个0~1的数

还是建议用python先训练好模型,不要用Java训练。模型训练好后,首先要做的是把pytorch模型转为TorchScript,TorchScript会把模型结构和参数都加载进去的

官网原文:

There are two ways to convert your model to TorchScript: tracing and scripting. We will only demonstrate the first one, tracing, but you can find information about scripting from the PyTorch documentation. When tracing, we use an example input to record the actions taken and capture the the model architecture. This works best when your model doesn't have control flow. If you do have control flow, you will need to use the scripting approach. In DJL, we use tracing to create TorchScript for our ModelZoo models.

Here is an example of tracing in actions:

import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18(pretrained=True)

# Switch the model to eval model
model.eval()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

# Save the TorchScript model
traced_script_module.save("traced_resnet_model.pt")

如果你使用了dropout等 一定要记得加上model.eval()再保存

对于我的来说 就下面这样

model = LinearModel()

model.load_state_dict(torch.load("model.pth"))

input = torch.tensor([0.72, 0.94]).float() //根据你的模型随便创建一个输入
    
script = torch.jit.trace(model, input)
    
script.save("model.pt")

然后该写Java代码了

官网例子:Load a PyTorch Model - Deep Java Library

还有这个:03 image classification with your model - Deep Java Library

我的数据就不需要transform了 代码:

//首先创建一个模型
Model model = Model.newInstance("test");
        try {
            model.load(Paths.get("C:\\Users\\Administrator\\IdeaProjects\\PytorchInJava\\src\\main\\resources\\model.pt"));
            System.out.println(model);

            //Predictor<参数类型,返回值类型> 输入图片的话参数是Image
            //我的参数是float32 不要写成Double
            Predictor<float[], Object> objectObjectPredictor = model.newPredictor(new NoBatchifyTranslator<float[], Object>() {
                @Override
                public NDList processInput(TranslatorContext translatorContext, float[] input) throws Exception {
                    NDManager ndManager = translatorContext.getNDManager();
                    NDArray ndArray = ndManager.create(input);
                    //ndArray作为输入
                    System.out.println(ndArray);
                    return new NDList(ndArray);
                }
                @Override
                public Object processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
                    System.out.println("process: " + ndList.get(0).getFloat());
                    return ndList.get(0).getFloat();
                }
            });

            float result = objectObjectPredictor.predict(new float[]{0.6144011f, 0.952401f});

            System.out.println("result: " + result);
        } catch (IOException e) {
            e.printStackTrace();
        } catch (MalformedModelException e) {
            e.printStackTrace();
        } catch (Exception e) {
            System.out.println("qunimade ");
            e.printStackTrace();
        }

输出:

更新

当我打包成jar到centos7的linux中运行时,报错UnsatisfiedLinkError,经过大神的指导,问题出在我引的依赖。

修改后的依赖:

    <properties>
        <java.version>8</java.version>
        <jna.version>5.3.0</jna.version>
    </properties>


    <dependencies>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.16.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu-precxx11</artifactId>
            <classifier>linux-x86_64</classifier>
            <version>1.9.1</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-jni</artifactId>
            <version>1.9.1-0.16.0</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
    </dependencies>

  • 4
    点赞
  • 35
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 23
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 23
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

欧内的手好汗

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值