java使用Deep Java Library(djl)搭配TorchScript搭建图片分类

一、前置要求

1.1、下载TorchScript类型的模型,注意这里是TorchScript类型,有些模型在说明中会说明是否为该格式的文件。可以从huggingface下载,在huggingface注意未区分PyTorch和TorchScript,在模型下方的标签都标记的为PyTorch,需要看具体的描述是否说该模型为TorchScript。
1.2、pom文件引入依赖,注意和引擎相关的包需要搭配引用,例如ai.djl.pytorch的native和jni包与engine版本要对上。pom.xml引入包如下:

<properties>
		<maven.compiler.source>11</maven.compiler.source>
		<maven.compiler.target>11</maven.compiler.target>
		<djl.version>0.27.0</djl.version>
	</properties>


	<dependencies>
		<!-- https://mvnrepository.com/artifact/ai.djl/api -->
		<dependency>
			<groupId>ai.djl</groupId>
			<artifactId>api</artifactId>
			<version>${djl.version}</version>
		</dependency>

		<!-- https://mvnrepository.com/artifact/ai.djl/model-zoo -->
		<dependency>
			<groupId>ai.djl</groupId>
			<artifactId>model-zoo</artifactId>
			<version>${djl.version}</version>
		</dependency>

		<!-- https://mvnrepository.com/artifact/ai.djl.pytorch/pytorch-engine -->
		<dependency>
			<groupId>ai.djl.pytorch</groupId>
			<artifactId>pytorch-engine</artifactId>
			<version>${djl.version}</version>
		</dependency>

		<dependency>
			<groupId>ai.djl</groupId>
			<artifactId>basicdataset</artifactId>
			<version>${djl.version}</version>
		</dependency>
	
		<dependency>
			<groupId>ai.djl.pytorch</groupId>
			<artifactId>pytorch-engine</artifactId>
			<version>${djl.version}</version>
		</dependency>
		
		<dependency>
			<groupId>ai.djl.pytorch</groupId>
			<artifactId>pytorch-jni</artifactId>
			<version>2.1.1-0.27.0</version>
		</dependency>
		<dependency>
			<groupId>ai.djl.pytorch</groupId>
			<artifactId>pytorch-native-cpu</artifactId>
			<classifier>win-x86_64</classifier>
			<version>2.1.1</version>
		</dependency>

		<dependency>
			<groupId>ai.djl</groupId>
			<artifactId>djl-zero</artifactId>
			<version>${djl.version}</version>
		</dependency>

		<dependency>
			<groupId>org.apache.logging.log4j</groupId>
			<artifactId>log4j-slf4j-impl</artifactId>
			<version>2.21.0</version>
		</dependency>
	</dependencies>

二、java代码

将下载好的模型放到对应的位置,其中模型文件包含两个部分,一个是pt结尾的文件,当然结尾不一定是这个,可能是其他的,可以使用压缩文件打开这个模型文件看看是否包含了constants.pkl等文件,这个可以用作确认是否为TorchScript的一个标志。然后还需要一个synset.txt文件。

//这里也可以使用在线的模型
private static final URL MODEL_URL = NSFWUtil.class.getClassLoader().getResource("model/xxx.pt");

	public static void main(String[] args) throws MalformedModelException, IOException, ModelNotFoundException, TranslateException {
		getNSFW4JSON("image path");
	}
	
	/**
	 * 
	 * @param imagePath 文件地址
	 * @throws ModelNotFoundException 
	 * @throws MalformedModelException
	 * @throws IOException
	 * @throws TranslateException
	 * @return nsfw的json
	 */
	public static Classifications  getNSFW4JSON(String imagePath) throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
		Image img = ImageFactory.getInstance().fromFile(Paths.get(imagePath));
		Translator<Image, Classifications> translator =
                ImageClassificationTranslator.builder()
                		.addTransform(new Resize(224, 224))
                        .addTransform(new ToTensor())
                        .optApplySoftmax(true)
                        .build();
		Criteria<Image, Classifications> criteria = Criteria.builder()
               .setTypes(Image.class, Classifications.class)
               .optModelUrls(MODEL_URL.toString())
               .optTranslator(translator)
               .optEngine("PyTorch") // Use PyTorch engine
               .optProgress(new ProgressBar())
               .build();
		try (ZooModel<Image, Classifications> model = criteria.loadModel())
		{
           Predictor<Image, Classifications> predictor = model.newPredictor();
           return predictor.predict(img);
       }
	}
	
	/**
	 * 
	 * @param in 输入流
	 * @throws ModelNotFoundException 
	 * @throws MalformedModelException
	 * @throws IOException
	 * @throws TranslateException
	 * @return nsfw的json
	 */
	public static Classifications  getNSFW4JSON(InputStream in) throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
		Image img = BufferedImageFactory.getInstance().fromInputStream(in);
		Translator<Image, Classifications> translator =
                ImageClassificationTranslator.builder()
                		.addTransform(new Resize(224, 224))
                        .addTransform(new ToTensor())
                        .optApplySoftmax(true)
                        .build();
		Criteria<Image, Classifications> criteria = Criteria.builder()
               .setTypes(Image.class, Classifications.class)
               .optModelUrls(MODEL_URL.toString())
               .optTranslator(translator)
               .optEngine("PyTorch") // Use PyTorch engine
               .optProgress(new ProgressBar())
               .build();
		try (ZooModel<Image, Classifications> model = criteria.loadModel())
		{
           Predictor<Image, Classifications> predictor = model.newPredictor();
           return predictor.predict(img);
       }
	}

三、总结

3.1、代码中的ImageClassificationTranslator在其他很多时候是自己在定义具体的方法实现,这里我们是图片分类,所以我们用的是官方提供的Translator。
3.2、就目前来说框架帮我们实现了很多的代码,能写的代码不是很多,难点在于如何找到能用的模型,目前很多模型还是PyTorch类型的,无法在JAVA或者C++环境调用。
3.3、可以试一下的模型nsfw,记住下synset.txt

  • 8
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值