参考（DJL 官方教程：加载 PyTorch 模型）：
https://docs.djl.ai/jupyter/load_pytorch_model.html
引入依赖（Maven pom.xml，以下为 Windows x86_64 CPU 版本的配置）：
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>djl</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <!-- Pin the source encoding so the build does not depend on the platform default. -->
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>

        <!-- All DJL artifacts share one release version. -->
        <djl.version>0.19.0</djl.version>
        <!-- PyTorch native build used by the engine. pytorch-jni is versioned as
             "<pytorch version>-<djl version>", so it is derived from both properties. -->
        <pytorch.version>1.12.1</pytorch.version>
        <!-- Platform of the bundled native PyTorch library. Override on other machines,
             e.g. mvn -Dpytorch.native.classifier=linux-x86_64 (or osx-x86_64). -->
        <pytorch.native.classifier>win-x86_64</pytorch.native.classifier>
    </properties>

    <dependencies>
        <!-- PyTorch engine implementation, only needed at runtime. -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>${djl.version}</version>
            <scope>runtime</scope>
        </dependency>
        <!-- Native PyTorch CPU binaries for the selected platform. -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu</artifactId>
            <classifier>${pytorch.native.classifier}</classifier>
            <scope>runtime</scope>
            <version>${pytorch.version}</version>
        </dependency>
        <!-- JNI bridge; must match both the PyTorch and the DJL version. -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-jni</artifactId>
            <version>${pytorch.version}-${djl.version}</version>
            <scope>runtime</scope>
        </dependency>
        <!-- DJL core API (Criteria, Predictor, Translator, ...). -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <!-- Provides DownloadUtils / ProgressBar used by the demo. -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <!-- OpenCV-backed ImageFactory implementation. -->
        <dependency>
            <groupId>ai.djl.opencv</groupId>
            <artifactId>opencv</artifactId>
            <version>${djl.version}</version>
        </dependency>
    </dependencies>
</project>
主要代码（下载 ResNet-18 TorchScript 模型，并对一张网络图片进行分类）：
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.nio.file.Paths;
/**
 * Minimal DJL demo: downloads a traced (TorchScript) ResNet-18 image classifier, runs it on one
 * image fetched from a URL, and prints the resulting classifications.
 *
 * <p>Usage: {@code java djl [imageUrl]}. When no URL argument is given, {@link #DEFAULT_IMAGE_URL}
 * is classified.
 */
public class djl {

    /** Directory the model file and its label list are downloaded into and loaded from. */
    private static final String MODEL_DIR = "build/pytorch_models/resnet18";

    /** Image classified when no URL is supplied on the command line. */
    private static final String DEFAULT_IMAGE_URL =
            "https://pics1.baidu.com/feed/f31fbe096b63f62423429aa837c752f11b4ca38f.jpeg@f_auto?token=49eaf3a702a1e10e3ba622433d1ce83e";

    /**
     * Entry point.
     *
     * @param args optional; {@code args[0]} is the URL of the image to classify
     * @throws ModelNotFoundException if no model matching the criteria can be found
     * @throws MalformedModelException if the model file cannot be loaded
     * @throws IOException if downloading the model or the image fails
     * @throws TranslateException if pre/post-processing or inference fails
     */
    public static void main(String[] args)
            throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
        String imageUrl = args.length > 0 ? args[0] : DEFAULT_IMAGE_URL;

        downloadModelArtifacts();

        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optModelPath(Paths.get(MODEL_DIR))
                .optOption("mapLocation", "true") // this model requires mapLocation for GPU
                .optTranslator(buildTranslator())
                .optProgress(new ProgressBar())
                .build();

        // Both the model and the predictor are AutoCloseable; close them so the resources they
        // hold are released even if prediction throws.
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Image img = ImageFactory.getInstance().fromUrl(imageUrl);
            Classifications classifications = predictor.predict(img);
            System.out.println(classifications);
        }
    }

    /**
     * Downloads the traced ResNet-18 model and its synset (label) file into {@link #MODEL_DIR}.
     * The model is published as a {@code .gz} archive and stored locally as {@code resnet18.pt}.
     *
     * @throws IOException if either download fails
     */
    private static void downloadModelArtifacts() throws IOException {
        DownloadUtils.download(
                "https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz",
                MODEL_DIR + "/resnet18.pt",
                new ProgressBar());
        DownloadUtils.download(
                "https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt",
                MODEL_DIR + "/synset.txt",
                new ProgressBar());
    }

    /**
     * Builds the image pre/post-processing pipeline: resize to 256, center-crop to 224x224,
     * convert to a tensor, normalize each channel, and apply softmax to the model output.
     *
     * @return translator mapping an {@link Image} to {@link Classifications}
     */
    private static Translator<Image, Classifications> buildTranslator() {
        return ImageClassificationTranslator.builder()
                .addTransform(new Resize(256))
                .addTransform(new CenterCrop(224, 224))
                .addTransform(new ToTensor())
                // Per-channel mean / std conventionally used with ImageNet-trained torchvision models.
                .addTransform(new Normalize(
                        new float[] {0.485f, 0.456f, 0.406f},
                        new float[] {0.229f, 0.224f, 0.225f}))
                .optApplySoftmax(true)
                .build();
    }
}
运行结果（Top-5 分类及其概率）：
[
class: "n01592084 chickadee", probability: 0.62068
class: "n01582220 magpie", probability: 0.18649
class: "n01580077 jay", probability: 0.06324
class: "n01560419 bulbul", probability: 0.02586
class: "n01601694 water ouzel, dipper", probability: 0.02014
]