“嘿,亲爱的小伙伴们!👋 今天我怀揣着满满的热情,精心准备了一篇文章,希望能像一缕清风,拂过我们求知的心田,带来一些新的思考与启发。让我们携手并进,在知识的海洋中扬帆远航,共同探索、学习、成长吧!🌟”
今天我与大家分享一篇题为“调用AI大模型识别菜品”的技术文章。博主的开发经验尚浅,如果文章写的不完美有漏洞请勿喷。
在给出具体的源代码前,先大致说一下调用AI 大模型的简单步骤。
一、相关步骤:
-
需要引入相关的依赖,包括PaddlePaddle和PyTorch的引擎和模型库。在这篇文章中博主是应用百度深度学习平台PaddlePaddle进行的。
(介绍:PaddlePaddle(飞桨)是百度公司研发的开源深度学习平台,旨在支持大规模数据的多GPU和多机器并行训练。它最初源于2013年百度深度学习实验室创建的Paddle项目,于2016年9月正式开源,成为国内首个全面开源开放、技术领先、功能完备的产业级深度学习平台)
-
定义模型的推理类,在这个菜品识别的demo中我创建了DishesClassification类,用于封装与菜肴分类相关的推理逻辑。
-
初始化日志记录器:在类中初始化一个日志记录器,用于记录该类的日志信息
-
定义预测方法:定义
predict
方法,该方法接收一个图像对象,并返回分类结果。该方法内部调用classfier
方法获取初步的分类结果,并对结果进行后处理(如softmax归一化) 介绍:Softmax归一化,又称归一化指数函数,是数学,特别是概率论和相关领域中一种重要的函数。Softmax函数本质上是一个带有归一化步骤的指数映射。它能够将一个K维实数向量z“压缩”到另一个K维实向量σ(z)中,使得每一个元素的范围都在(0,1)之间,并且所有元素的和为1。这种特性使得Softmax函数在多分类问题中得到了广泛应用。
具体来说,Softmax函数的计算公式如下:
其中,zj是输入向量z的第j个元素,ezj是对zj求指数,确保输出为正数,分母∑Kk-1 ezk是对所有元素的指数值求和,以实现归一化。
特点与应用
概率输出:Softmax函数的输出可以看作是一个概率分布,每个元素代表输入向量属于某个类别的概率。这使得Softmax函数在多分类问题中非常有用,可以直接输出每个类别的预测概率。
增强差异:通过指数映射,Softmax函数能够将输入值拉伸到更大的范围,从而增强输入向量中元素之间的差异。这使得在分类问题中,更有可能正确识别出概率最大的类别。
广泛应用:Softmax函数在包括多项逻辑回归、多项线性判别分析、朴素贝叶斯分类器和人工神经网络等的多种基于概率的多分类问题方法中都有着广泛应用。在神经网络中,Softmax函数常用于输出层,将网络的输出转换为概率分布,以便于分类决策。
示例:
假设有一个输入向量z=[2.0, 1.0, 0.1],我们可以计算它的Softmax输出:
-
计算每个元素的指数:**e2.0**≈7.39,**e1.0**≈2.72,**e0.1**≈1.10
-
计算这些指数值的和:7.39+2.72+1.10≈11.21
-
计算归一化后的概率值:
7.3911.21≈0.662.7211.21≈0.241.1011.21≈0.10\dfrac{7.39}{11.21} ≈0.66 \qquad \dfrac{2.72}{11.21} ≈0.24 \qquad \dfrac{1.10}{11.21} ≈0.1011.217.39≈0.6611.212.72≈0.2411.211.10≈0.10
因此,输入向量z经过Softmax归一化后的输出为[0.66, 0.24, 0.10],表示该输入向量被归一化为一个概率分布。
-
定义模型加载和推理方法:定义一个
classfier
方法,该方法负责加载模型并对图像进行推理。- 使用
Criteria.builder()
构建一个模型加载的标准,指定使用的引擎(如PaddlePaddle)、模型路 径、模型名称、翻译器(用于将模型输出转换为可理解的格式)和进度条 - 加载模型:使用
ModelZoo.loadModel(criteria)
方法加载模型,该方法根据提供的标准加载指定的模型。 - 创建预测器,使用加载的模型创建一个Predictor对象,该对象用于对图像进行预测
- 执行预测:使用预测器的
predict
方法对传入的图像进行预测,获取结果 - 返回结果:将分类结果(
Classifications
对象)返回给调用者
- 使用
-
示例:
public class Main { public static void main(String[] args) { try { // 假设您已经有一个Image对象img Image img = ...; // 加载或创建图像对象 // 调用预测方法 Classifications classifications = DishesClassification.predict(img); // 处理分类结果 // 例如,打印分类名称和概率 for (Classifications.Classification item : classifications.items()) { System.out.println("Class: " + item.getClassName() + ", Probability: " + item.getProbability()); } } catch (IOException | ModelException | TranslateException e) { e.printStackTrace(); } } }
注意,上述示例中的
Image
对象和图像加载逻辑需要您根据实际情况进行实现。此外,DishTranslator
类和ProgressBar
类也需要您根据需要进行定义或引入。
二、具体代码
一、新建一个Maven工程,在pom.xml中引入以下依赖
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.dish</groupId>
<artifactId>dish_identification</artifactId>
<version>0.0.1-SNAPSHOT</version>
<packaging>jar</packaging>
<name>dish_identification</name>
<url>http://maven.apache.org</url>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<djl.version>0.17.0</djl.version>
</properties>
<dependencies>
<!-- 单元测试 -->
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>3.8.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.6</version>
</dependency>
<dependency>
<groupId>commons-cli</groupId>
<artifactId>commons-cli</artifactId>
<version>1.4</version>
</dependency>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-slf4j-impl</artifactId>
<version>2.17.2</version>
</dependency>
<!-- 服务器端推理引擎 -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
<version>${djl.version}</version>
</dependency>
<!-- Pytorch -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-model-zoo</artifactId>
<version>${djl.version}</version>
</dependency>
<!-- PaddlePaddle -->
<dependency>
<groupId>ai.djl.paddlepaddle</groupId>
<artifactId>paddlepaddle-model-zoo</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl.paddlepaddle</groupId>
<artifactId>paddlepaddle-engine</artifactId>
<version>${djl.version}</version>
</dependency>
</dependencies>
</project>
二、定义模型推理工具类,工具类中需要包含预测方法,模型加载和模型推理逻辑等
package com.dish.utils;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.nio.file.Paths;
public final class DishesClassification {
// 初始化一个日志记录器,用于记录该类的日志信息
private static final Logger logger = LoggerFactory.getLogger(DishesClassification.class);
private DishesClassification() {}
/**
* 图像预测逻辑
* @param img 传入的图片
* @return
* @throws IOException
* @throws ModelException
* @throws TranslateException
*/
public static Classifications predict(Image img) throws IOException, ModelException, TranslateException {
// 调用classfier(img)方法来获取初步的分类结果
Classifications classifications = DishesClassification.classfier(img);
// 遍历分类结果中的每个项目,计算概率的总和以及最大概率
List<Classifications.Classification> items = classifications.items();
double sum = 0;
double max = 0;
double[] probArr = new double[items.size()];
// 对概率进行softmax归一化处理(通过指数函数和最大概率的调整,然后计算新的概率分布)
// 将处理后的分类名称和概率存储在新的列表中,并返回一个新的Classifications对象
List<String> names = new ArrayList<>();
List<Double> probs = new ArrayList<>();
for (int i = 0; i < items.size(); i++) {
Classifications.Classification item = items.get(i);
double prob = item.getProbability();
probArr[i] = prob;
if (prob > max) max = prob;
}
for (int i = 0; i < items.size(); i++) {
probArr[i] = Math.exp(probArr[i] - max);
sum = sum + probArr[i];
}
for (int i = 0; i < items.size(); i++) {
Classifications.Classification item = items.get(i);
names.add(item.getClassName());
probs.add(probArr[i]);
}
return new Classifications(names, probs);
}
/**
* 使用Criteria类构建一个模型加载的标准,包括指定使用的引擎(PaddlePaddle)、模型路径、模型名称、翻译器(DishTranslator)和进度条。
* 使用ModelZoo.loadModel(criteria)加载模型,并在try-with-resources语句中确保模型资源在使用后被正确关闭。
* 使用加载的模型创建一个Predictor对象,并用它来预测图像的分类。
* 返回预测结果,即Classifications对象。
* */
public static Classifications classfier(Image img) throws IOException, ModelException, TranslateException {
Criteria<Image, Classifications> criteria =
Criteria.builder()
.optEngine("PaddlePaddle")
.setTypes(Image.class, Classifications.class)
.optModelPath(Paths.get("models/dishes.zip"))
.optModelName("inference")
.optTranslator(new DishTranslator())
.optProgress(new ProgressBar())
.build();
try (ZooModel model = ModelZoo.loadModel(criteria)) {
try (Predictor<Image, Classifications> classifier = model.newPredictor()) {
Classifications classifications = classifier.predict(img);
return classifications;
}
}
}
}
三、写测试方法进行调用
package com.dish;
import ai.djl.ModelException;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.translate.TranslateException;
import com.dish.utils.DishesClassification;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
public final class DishesClassificationExample {
private static final Logger logger = LoggerFactory.getLogger(DishesClassificationExample.class);
private DishesClassificationExample() {}
public static void main(String[] args) throws IOException, ModelException, TranslateException {
// 这里放图片资源,进行识别
Path imageFile = Paths.get("src/main/resources/images/1.jpeg");
Image image = ImageFactory.getInstance().fromFile(imageFile);
Classifications classifications = DishesClassification.predict(image);
Classifications.Classification bestItem = classifications.best();
System.out.println(bestItem.getClassName() + " : " + bestItem.getProbability());
logger.info("{}", classifications);
}
}
四、查看执行的结果
下图是用于进行测试的图片:
执行结果(日志)
Loading: 100% |========================================|
WARNING: Logging before InitGoogleLogging() is written to STDERR
W1014 11:57:07.203495 19980 analysis_predictor.cc:1350] Deprecated. Please use CreatePredictor instead.
e[1me[35m--- Running analysis [ir_graph_build_pass]e[0m
e[1me[35m--- Running analysis [ir_graph_clean_pass]e[0m
e[1me[35m--- Running analysis [ir_analysis_pass]e[0m
e[32m--- Running IR pass [simplify_with_basic_ops_pass]e[0m
e[32m--- Running IR pass [layer_norm_fuse_pass]e[0m
e[37m--- Fused 0 subgraphs into layer_norm op.e[0m
e[32m--- Running IR pass [attention_lstm_fuse_pass]e[0m
e[32m--- Running IR pass [seqconv_eltadd_relu_fuse_pass]e[0m
e[32m--- Running IR pass [seqpool_cvm_concat_fuse_pass]e[0m
e[32m--- Running IR pass [mul_lstm_fuse_pass]e[0m
e[32m--- Running IR pass [fc_gru_fuse_pass]e[0m
e[37m--- fused 0 pairs of fc gru patternse[0m
e[32m--- Running IR pass [mul_gru_fuse_pass]e[0m
e[32m--- Running IR pass [seq_concat_fc_fuse_pass]e[0m
e[32m--- Running IR pass [squeeze2_matmul_fuse_pass]e[0m
e[32m--- Running IR pass [reshape2_matmul_fuse_pass]e[0m
e[32m--- Running IR pass [flatten2_matmul_fuse_pass]e[0m
e[32m--- Running IR pass [map_matmul_v2_to_mul_pass]e[0m
e[32m--- Running IR pass [map_matmul_v2_to_matmul_pass]e[0m
e[32m--- Running IR pass [map_matmul_to_mul_pass]e[0m
e[32m--- Running IR pass [fc_fuse_pass]e[0m
I1014 11:57:07.277503 19980 fuse_pass_base.cc:57] --- detected 1 subgraphs
e[32m--- Running IR pass [repeated_fc_relu_fuse_pass]e[0m
e[32m--- Running IR pass [squared_mat_sub_fuse_pass]e[0m
e[32m--- Running IR pass [conv_bn_fuse_pass]e[0m
I1014 11:57:07.328486 19980 fuse_pass_base.cc:57] --- detected 55 subgraphs
e[32m--- Running IR pass [conv_eltwiseadd_bn_fuse_pass]e[0m
e[32m--- Running IR pass [conv_transpose_bn_fuse_pass]e[0m
e[32m--- Running IR pass [conv_transpose_eltwiseadd_bn_fuse_pass]e[0m
e[32m--- Running IR pass [is_test_pass]e[0m
e[32m--- Running IR pass [runtime_context_cache_pass]e[0m
e[1me[35m--- Running analysis [ir_params_sync_among_devices_pass]e[0m
e[1me[35m--- Running analysis [adjust_cudnn_workspace_size_pass]e[0m
e[1me[35m--- Running analysis [inference_op_replace_pass]e[0m
e[1me[35m--- Running analysis [ir_graph_to_program_pass]e[0m
I1014 11:57:07.357105 19980 analysis_predictor.cc:714] ======= optimize end =======
I1014 11:57:07.358115 19980 naive_executor.cc:98] --- skip [feed], feed -> @HUB_resnet50_vd_dishes@image
I1014 11:57:07.359102 19980 naive_executor.cc:98] --- skip [save_infer_model/scale_0.tmp_2], fetch -> fetch
I1014 11:57:07.361105 19980 naive_executor.cc:98] --- skip [feed], feed -> @HUB_resnet50_vd_dishes@image
I1014 11:57:07.362116 19980 naive_executor.cc:98] --- skip [save_infer_model/scale_0.tmp_2], fetch -> fetch
热干面 : 1.0
Process finished with exit code 0
大模型&AI产品经理如何学习
求大家的点赞和收藏,我花2万买的大模型学习资料免费共享给你们,来看看有哪些东西。
1.学习路线图
第一阶段: 从大模型系统设计入手,讲解大模型的主要方法;
第二阶段: 在通过大模型提示词工程从Prompts角度入手更好发挥模型的作用;
第三阶段: 大模型平台应用开发借助阿里云PAI平台构建电商领域虚拟试衣系统;
第四阶段: 大模型知识库应用开发以LangChain框架为例,构建物流行业咨询智能问答系统;
第五阶段: 大模型微调开发借助以大健康、新零售、新媒体领域构建适合当前领域大模型;
第六阶段: 以SD多模态大模型为主,搭建了文生图小程序案例;
第七阶段: 以大模型平台应用与开发为主,通过星火大模型,文心大模型等成熟大模型构建大模型行业应用。
2.视频教程
网上虽然也有很多的学习资源,但基本上都残缺不全的,这是我自己整理的大模型视频教程,上面路线图的每一个知识点,我都有配套的视频讲解。
(都打包成一块的了,不能一一展开,总共300多集)
因篇幅有限,仅展示部分资料,需要点击下方图片前往获取
3.技术文档和电子书
这里主要整理了大模型相关PDF书籍、行业报告、文档,有几百本,都是目前行业最新的。
4.LLM面试题和面经合集
这里主要整理了行业目前最新的大模型面试题和各种大厂offer面经合集。
👉学会后的收获:👈
• 基于大模型全栈工程实现(前端、后端、产品经理、设计、数据分析等),通过这门课可获得不同能力;
• 能够利用大模型解决相关实际项目需求: 大数据时代,越来越多的企业和机构需要处理海量数据,利用大模型技术可以更好地处理这些数据,提高数据分析和决策的准确性。因此,掌握大模型应用开发技能,可以让程序员更好地应对实际项目需求;
• 基于大模型和企业数据AI应用开发,实现大模型理论、掌握GPU算力、硬件、LangChain开发框架和项目实战技能, 学会Fine-tuning垂直训练大模型(数据准备、数据蒸馏、大模型部署)一站式掌握;
• 能够完成时下热门大模型垂直领域模型训练能力,提高程序员的编码能力: 大模型应用开发需要掌握机器学习算法、深度学习框架等技术,这些技术的掌握可以提高程序员的编码能力和分析能力,让程序员更加熟练地编写高质量的代码。
1.AI大模型学习路线图
2.100套AI大模型商业化落地方案
3.100集大模型视频教程
4.200本大模型PDF书籍
5.LLM面试题合集
6.AI产品经理资源合集
👉获取方式:
😝有需要的小伙伴,可以保存图片到wx扫描二v码免费领取【保证100%免费】🆓