一、前言
近期学习了DJL(深度java学习),有了一点小的研究成果,特以此博客分享给大家。这个技术是一个特别新的技术,是亚马逊云服务在2019年re:Invent大会推出的专为Java开发者量身定制的深度学习框架,网上的资料比较少,只有官方文档可以参考,研究起来难度比较大,但是经过不懈的努力,终于搞定了,接下来以官网的demo入门。由于这块有很多坑,所以有必要好好的说一下。
官网地址:https://docs.djl.ai/jupyter/load_pytorch_model.html
二、demo
1、创建SpringBoot项目,导入pom依赖
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.6.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.6.0</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-auto</artifactId>
<version>1.5.0</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.26</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
<version>1.7.26</version>
</dependency>
<dependency>
<groupId>net.java.dev.jna</groupId>
<artifactId>jna</artifactId>
<version>5.3.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<classifier>win-x86_64</classifier>
<scope>runtime</scope>
<version>1.5.0</version>
</dependency>
</dependencies>
2、下载模型
DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());
下载完成后生成一个build文件夹,里面有你的模型(如下下载失败,请翻墙,连接外网)
3、创建一个Translator
Pipeline pipeline = new Pipeline();
pipeline.add(new Resize(256))
.add(new CenterCrop(224, 224))
.add(new ToTensor())
.add(new Normalize(
new float[]{0.485f, 0.456f, 0.406f},
new float[]{0.229f, 0.224f, 0.225f}));
Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
.setPipeline(pipeline)
.optApplySoftmax(true)
.build();
4、加载你的模型resnet18
System.setProperty("ai.djl.repository.zoo.location", "build/pytorch_models/resnet18");
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class)
// only search the model in local directory
// "ai.djl.localmodelzoo:{name of the model}"
.optArtifactId("ai.djl.localmodelzoo:resnet18")
.optTranslator(translator)
.optProgress(new ProgressBar()).build();
ZooModel model = ModelZoo.loadModel(criteria);
5、使用图片进行预测
// 自己本地
File fs=new File("D:\\testdjl\\dog.jpg");
Image img = ImageFactory.getInstance().fromInputStream(new FileInputStream(fs));
Predictor<Image, Classifications> predictor = model.newPredictor();
Classifications classifications = predictor.predict(img);
System.out.println(classifications);
6、执行结果
三、在运行的时候可能会报如下的错
1、No deep learning engine found
官网给出地址如下:
https://github.com/awslabs/djl/blob/master/docs/development/troubleshooting.md
我是 通过解决了nsatisfiedLinkError问题,解决的No deep learning engine found错误,官网有一个提示:CN:如果您在中国,可以使用DirectX修复工具来安装遗失依赖项
。所以我就试着通过DirectX这个修复工具进行修复。
下载地址:https://www.onlinedown.net/soft/120082.htm,下载完之后安装就行了,如下:
2、路径中有中文(open file faild)
将中文改成英文就好了
3、下载模型失败,记得翻墙
后记:
如果对你有所帮助,请记得点赞。