前提
需要对接豆包大模型知识库能力的个人用户或企业,可以参考这份从官方获取的 SDK 示例。示例包含流式和非流式请求 demo,基于知识库检索实现大模型对话能力。本人已完成对接并验证通过,有疑问可以私聊!
代码如下
引入 Maven 依赖
<!-- 豆包 -->
<dependency>
<groupId>com.volcengine</groupId>
<artifactId>volcengine-java-sdk-ark-runtime</artifactId>
<version>LATEST</version>
</dependency>
<dependency>
<groupId>com.volcengine</groupId>
<artifactId>volc-sdk-java</artifactId>
<version>1.0.206</version>
</dependency>
Java 代码
package org.example;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.volcengine.auth.ISignerV4;
import com.volcengine.auth.impl.SignerV4Impl;
import com.volcengine.model.Credentials;
import com.volcengine.service.SignableRequest;
import lombok.Data;
import lombok.experimental.Accessors;
import org.apache.http.Header;
import org.apache.http.HttpResponse;
import org.apache.http.NameValuePair;
import org.apache.http.client.HttpClient;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.*;
public class Main {
// Host name of the knowledge-base service.
public static final String HOST = "api-knowledgebase.mlp.cn-beijing.volces.com";
// Knowledge-base retrieval endpoint. Recommended for first-time integrations; the other retrieval endpoints will no longer be maintained.
public static final String SEARCH_KNOWLEDGE_PATH = "/api/knowledge/collection/search_knowledge";
// LLM chat endpoint. It can be chained after retrieval to form a RAG pipeline, or used on its own for generation.
public static final String CHAT_COMPLETION_PATH = "/api/knowledge/chat/completions";
// Placeholder access key / secret key handed to prepareRequest for request signing; replace with real credentials.
// NOTE(review): consider loading these from configuration instead of hard-coding them in source.
public static final String AK = "your ak";
public static final String SK = "your sk";
// The user's question.
static String Query = "your query";
// Knowledge-base name; available from the console UI.
static String CollectionName = "your collection";
// Project the knowledge base belongs to; available from the console UI.
static String Project = "default";
// Knowledge-base resource id. In this file it is only referenced from a commented-out line in GenerateSearchKnowledgeRequest.
static String ResourceID = "your resource_id";
// Model name. To use your own private endpoint, set this to the private endpoint id (format ep-xxxx-xxxx).
static String ModelName = "Doubao-pro-32k";
// API key. Required when using a private endpoint (format ep-xxxx-xxxx); not needed for a public access point.
static String APIKey = "your api_key";
// Model version. Optional with a public access point; when omitted the service picks its default version.
static String ModelVersion = "241215";
// BasePrompt is the base system prompt. You may edit it or replace it with your own prompt, but keep the
// {prompt} placeholder: GeneratePrompt replaces it with the retrieved context for RAG generation.
static String BasePrompt = "# 任务\n"
+ "你是一位在线客服,你的首要任务是通过巧妙的话术回复用户的问题,你需要根据「参考资料」来回答接下来的「用户问题」,这些信息在 <context></context> XML tags 之内,你需要根据参考资料给出准确,简洁的回答。\n"
+ "\n"
+ "你的回答要满足以下要求:\n"
+ "1. 回答内容必须在参考资料范围内,尽可能简洁地回答问题,不能做任何参考资料以外的扩展解释。\n"
+ "2. 回答中需要根据客户问题和参考资料保持与客户的友好沟通。\n"
+ "3. 如果参考资料不能帮助你回答用户问题,告知客户无法回答该问题,并引导客户提供更加详细的信息。\n"
+ "4. 为了保密需要,委婉地拒绝回答有关参考资料的文档名称或文档作者等问题。\n"
+ "\n"
+ "# 任务执行\n"
+ "现在请你根据提供的参考资料,遵循限制来回答用户的问题,你的回答需要准确和完整。\n"
+ "\n"
+ "# 参考资料\n"
+ "<context>\n"
+ "{prompt}\n"
+ "</context>";
// Names of the system fields that GeneratePrompt recognises when assembling the prompt context.
public static final String SYS_FIELD_DOC_NAME = "doc_name";
public static final String SYS_FIELD_TITLE = "title";
public static final String SYS_FIELD_CHUNK_TITLE = "chunk_title";
public static final String SYS_FIELD_CONTENT = "content";
/*
Custom fields to splice into the prompt:
If your knowledge base is of the **structured data** type, replace these entries with the table-header
fields you want to include in the LLM prompt.
Note: header fields that are indexed must be included; non-indexed header fields are optional.
*/
static List<String> selfDefineFields = Arrays.asList(
"表头字段-1",
"表头字段-n"
);
/*
System fields to splice into the prompt:
If your knowledge base is of the **unstructured data** type, the following 4 fields can be included in
the prompt. The content field is required; the others are optional.
*/
static List<String> systemFields = Arrays.asList(
SYS_FIELD_DOC_NAME, // document name, optional
SYS_FIELD_TITLE, // document title, optional
SYS_FIELD_CHUNK_TITLE, // chunk title, optional
SYS_FIELD_CONTENT // chunk content, required
);
// Field selection read by GeneratePrompt (both lists above).
public static final PromptExtraContext promptExtraContextExample = new PromptExtraContext(selfDefineFields, systemFields);
// Pair of field-name lists that GeneratePrompt uses to decide which values of each retrieved chunk go into the prompt.
@Data
public static class PromptExtraContext {
// Custom field names, matched against PointTableChunkField.fieldName in GeneratePrompt.
private List<String> SelfDefineFields;
// System field names, matched against the SYS_FIELD_* constants in GeneratePrompt.
private List<String> SystemFields;
public PromptExtraContext(List<String> selfDefineFields, List<String> systemFields) {
this.SelfDefineFields = selfDefineFields;
this.SystemFields = systemFields;
}
}
// Request body for the knowledge-base retrieval call; SearchKnowledge serializes it to JSON via toJson.
@Data
public static class SearchKnowledgeRequest {
// Knowledge-base name (GenerateSearchKnowledgeRequest fills it from CollectionName).
private String name;
// Project the knowledge base belongs to.
private String project;
// Knowledge-base resource id; an alternative to name + project.
@JsonProperty("resource_id")
private String resourceId;
// The user's question.
private String query;
// Number of results finally returned by retrieval.
private Integer limit;
// NOTE(review): presumably the weight of dense (vector) retrieval in hybrid search - confirm against the API docs.
@JsonProperty("dense_weight")
private Float denseWeight;
@JsonProperty("md_search")
private Boolean mdSearch;
@JsonProperty("query_param")
private QueryParam queryParam;
// Optional pre-retrieval processing options.
@JsonProperty("pre_processing")
private PreProcessing preProcessing;
// Optional post-retrieval processing options.
@JsonProperty("post_processing")
private PostProcessing postProcessing;
}
// Extra query parameters nested under "query_param" in SearchKnowledgeRequest.
@Data
public static class QueryParam {
// Free-form filter object serialized as "doc_filter"; its schema is not defined in this file.
@JsonProperty("doc_filter")
private Object docFilter;
}
// Pre-retrieval processing options nested under "pre_processing" in SearchKnowledgeRequest.
@Data
public static class PreProcessing {
// Whether the embedding model should prepend an instruction (see the official documentation).
@JsonProperty("need_instruction")
private Boolean needInstruction;
// Query-rewrite switch; rewriting is based on the multi-turn conversation.
private Boolean rewrite;
// NOTE(review): presumably the conversation history used by query rewriting - confirm against the API docs.
private List<MessageParam> messages;
// Whether retrieval should return token-usage information.
@JsonProperty("return_token_usage")
private Boolean returnTokenUsage;
}
// Post-retrieval processing options (rerank etc.) nested under "post_processing" in SearchKnowledgeRequest.
@Data
public static class PostProcessing {
// Rerank switch.
@JsonProperty("rerank_switch")
private Boolean rerankSwitch;
@JsonProperty("rerank_model")
private String rerankModel;
@JsonProperty("rerank_only_chunk")
private Boolean rerankOnlyChunk;
// Number of chunks fed into rerank; only effective when rerank is on and must be greater than limit.
@JsonProperty("retrieve_count")
private Integer retrieveCount;
@JsonProperty("endpoint_id")
private String endpointId;
@JsonProperty("chunk_diffusion_count")
private Integer chunkDiffusionCount;
// Whether recalled chunks are aggregated by document.
@JsonProperty("chunk_group")
private Boolean chunkGroup;
@JsonProperty("chunk_score_aggr_type")
private String chunkScoreAggrType;
@JsonProperty("chunk_extra_content")
private Map<String, Object> chunkExtraContent;
@JsonProperty("get_attachment_link")
private Boolean getAttachmentLink;
}
// One chat message sent to the LLM.
@Data
public static class MessageParam {
// Message role; BuildMessages uses "system" and "user".
private String role;
// Typed as Object because BuildMessages passes either a plain String or a List<ChatCompletionMessageContentPart>.
private Object content;
MessageParam(String role, Object content) {
this.role = role;
this.content = content;
}
}
// Wrapper around an image URL, used as the image payload of a ChatCompletionMessageContentPart.
@Data
public static class ChatMessageImageURL {
private String url;
ChatMessageImageURL(String url) {
this.url = url;
}
}
// Kind of a multimodal content part: plain text or an image reference.
// NOTE(review): Jackson serializes these by constant name ("TEXT", "IMAGE_URL") by default - confirm that
// matches what the chat API expects.
public enum ChatCompletionMessageContentPartType {
TEXT, IMAGE_URL
}
// One element of a multimodal user message; BuildMessages creates a TEXT part for the query and an
// IMAGE_URL part per retrieved image.
// NOTE(review): no @JsonProperty is declared here, so imageURL is serialized under its Java-derived name;
// verify the JSON key the chat API expects for the image part.
@Data
public static class ChatCompletionMessageContentPart {
private ChatCompletionMessageContentPartType type;
// Text payload; BuildMessages passes null for image parts.
private String text;
// Image payload; BuildMessages passes null for text parts.
private ChatMessageImageURL imageURL;
public ChatCompletionMessageContentPart(ChatCompletionMessageContentPartType type, String text, ChatMessageImageURL imageURL) {
this.type = type;
this.text = text;
this.imageURL = imageURL;
}
}
// Generic response envelope used for both the retrieval and the chat endpoints.
@Data
public static class BaseResponse<T> {
// Status code; GeneratePrompt treats any non-zero value as a failure.
private Integer code;
// Message; GeneratePrompt uses it as the error text when code is non-zero.
private String message;
@JsonProperty("request_id")
private String requestId;
// Endpoint-specific payload.
private T data;
}
// Payload of a retrieval response.
@Data
public static class CollectionSearchKnowledgeResponseData {
@JsonProperty("collection_name")
private String collectionName; // knowledge-base collection_name
private Integer count; // number of chunks returned
@JsonProperty("rewrite_query")
private String rewriteQuery; // rewritten query; returned when query rewriting is enabled
@JsonProperty("token_usage")
private TotalTokenUsage tokenUsage; // model usage details
@JsonProperty("result_list")
private List<CollectionSearchResponseItem> resultList; // returned chunks
}
// Token-usage breakdown returned by retrieval under "token_usage".
@Data
public static class TotalTokenUsage {
@JsonProperty("embedding_token_usage")
private ModelTokenUsage embeddingUsage;
// Unlike the other entries this one is a single number rather than a ModelTokenUsage object.
@JsonProperty("rerank_token_usage")
private Long rerankUsage;
@JsonProperty("llm_token_usage")
private ModelTokenUsage llmUsage;
@JsonProperty("rewrite_token_usage")
private ModelTokenUsage rewriteUsage;
}
/*
One chunk returned by retrieval, with its individual fields.
GeneratePrompt and GetContentForPrompt read content, originalQuestion, chunkTitle, docInfo,
chunkAttachmentList and tableChunkFields; the remaining fields are only deserialized.
*/
@Data
public static class CollectionSearchResponseItem {
private String id;
// Chunk text placed into the prompt.
private String content;
@JsonProperty("md_content")
private String mdContent;
private Double score;
@JsonProperty("point_id")
private String pointId;
@JsonProperty("origin_text")
private String originText;
// When present, GetContentForPrompt renders the chunk as a question/answer pair.
@JsonProperty("original_question")
private String originalQuestion;
@JsonProperty("chunk_title")
private String chunkTitle;
@JsonProperty("chunk_id")
private Integer chunkId;
@JsonProperty("process_time")
private Long processTime;
@JsonProperty("rerank_score")
private Double rerankScore;
@JsonProperty("doc_info")
private CollectionSearchResponseItemDocInfo docInfo;
@JsonProperty("recall_position")
private Integer recallPosition;
@JsonProperty("rerank_position")
private Integer rerankPosition;
@JsonProperty("chunk_type")
private String chunkType;
@JsonProperty("chunk_source")
private String chunkSource;
@JsonProperty("update_time")
private Long updateTime;
// Attachments of the chunk; only the first attachment's link is used by this file.
@JsonProperty("chunk_attachment")
private List<ChunkAttachment> chunkAttachmentList;
// Name/value pairs matched against the custom field list in GeneratePrompt.
@JsonProperty("table_chunk_fields")
private List<PointTableChunkField> tableChunkFields;
@JsonProperty("original_coordinate")
private ChunkPositions originalCoordinate;
}
// Document-level metadata attached to a retrieved chunk; GeneratePrompt reads docName and title.
@Data
public static class CollectionSearchResponseItemDocInfo {
@JsonProperty("doc_id")
private String docId;
@JsonProperty("doc_name")
private String docName;
@JsonProperty("create_time")
private Long createTime;
@JsonProperty("doc_type")
private String docType;
@JsonProperty("doc_meta")
private String docMeta;
private String source;
private String title;
}
// Attachment of a retrieved chunk; this file only reads link (used as an image URL for vision models).
@Data
public static class ChunkAttachment {
private String uuid;
private String caption;
private String type;
private String link;
}
// One name/value pair of a chunk's table fields.
@Data
public static class PointTableChunkField {
@JsonProperty("field_name")
private String fieldName;
@JsonProperty("field_value")
private Object fieldValue; // Using Object to accommodate any type of field value
}
// Position information of a chunk, deserialized from "original_coordinate"; not read elsewhere in this file.
@Data
public static class ChunkPositions {
@JsonProperty("page_no")
private List<Integer> pageNo;
private List<List<Double>> bbox; // A list of lists to represent bounding box coordinates
}
// Request body for the chat-completion call; populated by GenerateChatCompletionRequest and serialized via toJson.
// NOTE(review): several fields start with an upper-case letter while also carrying @JsonProperty. With
// Lombok-generated getters Jackson may emit an additional property derived from the getter name - verify
// the serialized JSON.
@Data
public static class ChatCompletionRequest {
// Model name, or a private endpoint id.
private String model;
@JsonProperty("model_version")
private String modelVersion;
private String project;
// Whether the result is returned as a stream.
private Boolean Stream;
// Whether token usage of the LLM call is returned.
@JsonProperty("return_token_usage")
private Boolean ReturnTokenUsage;
// API key, required when a private endpoint is used.
@JsonProperty("api_key")
private String APIKey;
// Maximum number of output tokens.
@JsonProperty("max_tokens")
private Integer MaxTokens;
// Sampling temperature.
private double Temperature;
// Conversation messages, in order.
private List<MessageParam> messages;
}
// Payload of a chat-completion response (one per response, or one per event when streaming).
@Data
public static class CollectionChatCompletionResponseData {
@JsonProperty("generated_answer")
private String GenerateAnswer; // text generated by the model
private String Usage; // serialized ModelTokenUsage; deserialize it yourself if usage details are needed
@JsonProperty("reasoning_content")
private String ReasoningContent; // the model's reasoning content; only present for reasoning models
private Boolean End; // marks the last event of a streaming response
}
// Token counts of a single model invocation.
@Data
public static class ModelTokenUsage {
@JsonProperty("prompt_tokens")
private Integer PromptTokens; // number of tokens in the request text
@JsonProperty("completion_tokens")
private Integer CompletionTokens; // number of tokens in the generated text; only set for LLM chat models, 0 for all others
@JsonProperty("total_tokens")
private Integer TotalTokens; // PromptTokens + CompletionTokens
}
// Result of GeneratePrompt: either a prompt plus image URLs, or an error describing why none could be built.
public static class PromptResult {
// Final system prompt; GeneratePrompt sets it to "" on failure.
String prompt;
// Image URLs collected for vision models; GeneratePrompt sets it to null on failure.
List<String> imageURLs;
// Non-null when prompt generation failed.
Exception error;
public PromptResult(String prompt, List<String> imageURLs, Exception error) {
this.prompt = prompt;
this.imageURLs = imageURLs;
this.error = error;
}
}
/**
 * Serializes an object to its JSON representation with Jackson.
 *
 * <p>Serialization failures are reported immediately instead of returning {@code null}: every caller in
 * this class hands the result straight to the request builder, where a {@code null} body only fails
 * later with an unrelated message.
 *
 * @param obj the object to serialize
 * @return the JSON string
 * @throws IllegalStateException if Jackson cannot serialize {@code obj}; the original failure is the cause
 */
public static String toJson(Object obj) {
    try {
        return new ObjectMapper().writeValueAsString(obj);
    } catch (Exception e) {
        throw new IllegalStateException("Failed to serialize object to JSON", e);
    }
}
/**
 * Builds a JSON request for the knowledge-base service and signs it with the given credentials.
 *
 * @param host   target host; used for the request URI and for the {@code Host} header
 * @param path   request path
 * @param method HTTP method name
 * @param params optional query parameters; skipped when {@code null}
 * @param body   JSON request body, encoded as UTF-8
 * @param ak     access key id
 * @param sk     secret access key
 * @return the signed request; callers copy its config, entity and headers onto the outgoing HTTP call
 * @throws Exception if the request cannot be built or signed
 */
public static SignableRequest prepareRequest(String host, String path, String method, List<NameValuePair> params, String body, String ak, String sk) throws Exception {
    SignableRequest request = new SignableRequest();
    request.setMethod(method);
    request.setHeader("Accept", "application/json");
    request.setHeader("Content-Type", "application/json");
    // Use the host argument, not the class constant, so the Host header matches the URI built below.
    request.setHeader("Host", host);
    request.setEntity(new StringEntity(body, "utf-8"));

    URIBuilder builder = request.getUriBuilder();
    builder.setScheme("https");
    builder.setHost(host);
    builder.setPath(path);
    if (params != null) {
        builder.setParameters(params);
    }

    // 120 s socket timeout, 12 s connect timeout.
    RequestConfig requestConfig = RequestConfig.custom().setSocketTimeout(120000).setConnectTimeout(12000).build();
    request.setConfig(requestConfig);

    // Region and service name used for signing, as given in the original sample.
    Credentials credentials = new Credentials("cn-north-1", "air");
    credentials.setAccessKeyID(ak);
    credentials.setSecretAccessKey(sk);

    // Sign the request.
    ISignerV4 signer = new SignerV4Impl();
    signer.sign(request, credentials);
    return request;
}
/**
 * Builds the parameters of a knowledge-base retrieval request. Only some parameters are demonstrated
 * here; see the official API documentation for the rest. For a quick test it is enough to pass the
 * minimal set and leave everything else at its default.
 *
 * <p>Required parameters:
 * <ol>
 *   <li>the knowledge base, identified either by resource_id or by name + project (one of the two)</li>
 *   <li>query: the user's question</li>
 * </ol>
 *
 * @return the populated request, including both the pre- and post-processing sections
 */
public static SearchKnowledgeRequest GenerateSearchKnowledgeRequest() {
    SearchKnowledgeRequest requestParams = new SearchKnowledgeRequest();
    requestParams.setQuery(Query);
    requestParams.setName(CollectionName); // optional
    requestParams.setProject(Project); // optional
    //requestParams.setResourceId(ResourceID); // knowledge-base resource_id (alternative to name + project)
    requestParams.setLimit(10); // number of results finally returned; defaults to 10 when omitted
    requestParams.setDenseWeight(0.5f);

    // Pre-processing: optional extras such as query rewriting and intent recognition.
    PreProcessing preprocessing = new PreProcessing();
    preprocessing.setNeedInstruction(true); // whether the embedding model prepends an instruction; see the official docs
    preprocessing.setReturnTokenUsage(true); // whether retrieval returns token-usage information
    preprocessing.setRewrite(false); // query-rewrite switch (rewrites based on the multi-turn conversation); off by default
    requestParams.setPreProcessing(preprocessing);

    // Post-processing: optional extras such as rerank.
    PostProcessing postProcessing = new PostProcessing();
    postProcessing.setRerankSwitch(false); // rerank switch; off by default
    postProcessing.setRetrieveCount(25); // chunks fed into rerank; effective only with rerank on, must exceed limit (default 25 when limit=10)
    postProcessing.setChunkGroup(true); // whether recalled chunks are aggregated by document
    // Attach the section to the request; without this call the settings above are never sent.
    requestParams.setPostProcessing(postProcessing);
    return requestParams;
}
/**
 * Calls the knowledge-base retrieval endpoint with the request built by
 * {@code GenerateSearchKnowledgeRequest} and deserializes the JSON response.
 *
 * @return the parsed response envelope
 * @throws Exception if signing, the HTTP call or JSON parsing fails; the stack trace is printed before rethrowing
 */
public static BaseResponse<CollectionSearchKnowledgeResponseData> SearchKnowledge() throws Exception {
    SearchKnowledgeRequest searchRequestParams = GenerateSearchKnowledgeRequest();
    String searchRequestParamsJson = toJson(searchRequestParams);
    // Callers may add more detailed error handling as needed.
    try {
        SignableRequest signableRequest = prepareRequest(HOST, SEARCH_KNOWLEDGE_PATH, "POST", null, searchRequestParamsJson, AK, SK);
        URI uri = new URIBuilder()
                .setScheme("https")
                .setHost(HOST)
                .setPath(SEARCH_KNOWLEDGE_PATH)
                .build();
        HttpPost httpPost = new HttpPost(uri);
        httpPost.setConfig(signableRequest.getConfig());
        httpPost.setEntity(signableRequest.getEntity());
        // Copy the signed headers (including the signature) onto the outgoing request.
        for (Header header : signableRequest.getAllHeaders()) {
            httpPost.setHeader(header.getName(), header.getValue());
        }
        // try-with-resources releases the connection and the client even when parsing fails.
        try (CloseableHttpClient httpClient = HttpClients.createDefault();
             CloseableHttpResponse response = httpClient.execute(httpPost)) {
            // Fall back to UTF-8 (not ISO-8859-1) when the response does not declare a charset.
            String responseBody = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
            return new ObjectMapper().readValue(responseBody, new TypeReference<BaseResponse<CollectionSearchKnowledgeResponseData>>() {
            });
        }
    } catch (Exception e) {
        e.printStackTrace();
        throw e;
    }
}
/**
 * Builds the chat-completion request from the class-level model settings.
 *
 * @param Stream   whether the model result should be returned as a stream
 * @param messages conversation messages; their format and order must follow the official documentation
 * @return the populated request
 */
public static ChatCompletionRequest GenerateChatCompletionRequest(Boolean Stream, List<MessageParam> messages) {
    ChatCompletionRequest requestParams = new ChatCompletionRequest();
    requestParams.setModel(ModelName); // a model name means the public access point; use a private endpoint id (ep-xxx-xxx) instead if needed
    requestParams.setModelVersion(ModelVersion); // optional with the public access point; the service picks a default when omitted
    requestParams.setReturnTokenUsage(true); // whether token usage of the LLM call is returned
    requestParams.setAPIKey(APIKey); // required for a private endpoint (model = ep-xxx-xxx) to take effect
    requestParams.setMaxTokens(4096); // maximum output tokens; 4096 is currently the upper limit
    // A double literal keeps the value exactly 0.7; the float literal 0.7f widens to 0.699999988...
    requestParams.setTemperature(0.7); // model temperature, range 0~1; higher means more random
    requestParams.setStream(Stream); // whether the result is streamed
    requestParams.setMessages(messages); // conversation messages
    return requestParams;
}
/**
 * Renders the content of one retrieved chunk for inclusion in the prompt.
 *
 * <ul>
 *   <li>If the chunk carries a non-empty original question, it is rendered as a question/answer pair.</li>
 *   <li>Otherwise, if {@code imageNum > 0} and the chunk's first attachment has a link, an image
 *       placeholder numbered {@code imageNum} is appended to the content.</li>
 *   <li>Otherwise the content is returned unchanged.</li>
 * </ul>
 *
 * @param item     the retrieved chunk
 * @param imageNum running number of the chunk's image, or 0 when images are not used
 * @return the text to place into the prompt for this chunk
 */
public static String GetContentForPrompt(CollectionSearchResponseItem item, int imageNum) {
    String content = item.getContent();
    String originalQuestion = item.getOriginalQuestion();
    // The field is null when the response omits it, so it must be checked for null as well as for "".
    if (originalQuestion != null && !originalQuestion.isEmpty()) {
        return String.format("当询问到相似问题时,请参考对应答案进行回答:问题:“%s”。答案:“%s”", originalQuestion, content);
    }
    List<ChunkAttachment> attachments = item.getChunkAttachmentList();
    if (imageNum > 0 && attachments != null && !attachments.isEmpty()) {
        String link = attachments.get(0).getLink();
        if (link != null && !link.isEmpty()) {
            return content + String.format("<img>图片%d</img>", imageNum);
        }
    }
    return content;
}
/**
 * Builds the system prompt for the LLM by splicing the retrieval results into {@code BasePrompt}.
 *
 * <p>For every retrieved chunk the configured system fields and, for structured knowledge bases, the
 * configured custom fields are appended as {@code name: value} lines, followed by a {@code ---}
 * separator. When the configured model is a vision model, the first attachment link of each chunk is
 * also collected as an image URL.
 *
 * @param resp the retrieval response
 * @return a result holding the prompt and the image URLs; on a missing or failed response the prompt
 *         is empty, the image list is {@code null} and {@code error} is set
 */
public static PromptResult GeneratePrompt(BaseResponse<CollectionSearchKnowledgeResponseData> resp) {
    if (resp == null) {
        return new PromptResult("", null, new IllegalArgumentException("response is nil"));
    }
    // A missing code is reported as a failure rather than causing an unboxing NullPointerException.
    if (resp.getCode() == null || resp.getCode() != 0) {
        return new PromptResult("", null, new IllegalArgumentException(resp.getMessage()));
    }
    CollectionSearchKnowledgeResponseData data = resp.getData();
    List<CollectionSearchResponseItem> items = (data == null || data.getResultList() == null)
            ? Collections.<CollectionSearchResponseItem>emptyList()
            : data.getResultList();

    StringBuilder promptBuilder = new StringBuilder();
    List<String> imageURLs = new ArrayList<>();
    // Check the configured model name itself, not the literal text "ModelName".
    boolean usingVLM = isVisionModel(ModelName);
    int imageCnt = 0;
    for (CollectionSearchResponseItem item : items) {
        List<ChunkAttachment> attachments = item.getChunkAttachmentList();
        if (usingVLM && attachments != null && !attachments.isEmpty()) {
            String link = attachments.get(0).getLink();
            if (link != null && !link.isEmpty()) {
                imageURLs.add(link);
                imageCnt++;
            }
        }
        // Document metadata may be absent from a chunk.
        CollectionSearchResponseItemDocInfo docInfo = item.getDocInfo();
        // Append the system fields selected by the user.
        for (String sysField : promptExtraContextExample.SystemFields) {
            switch (sysField) {
                case SYS_FIELD_DOC_NAME:
                    promptBuilder.append(String.format("%s: %s\n", sysField, docInfo == null ? null : docInfo.getDocName()));
                    break;
                case SYS_FIELD_TITLE:
                    promptBuilder.append(String.format("%s: %s\n", sysField, docInfo == null ? null : docInfo.getTitle()));
                    break;
                case SYS_FIELD_CHUNK_TITLE:
                    promptBuilder.append(String.format("%s: %s\n", sysField, item.getChunkTitle()));
                    break;
                case SYS_FIELD_CONTENT:
                    promptBuilder.append(String.format("%s: %s\n", sysField, GetContentForPrompt(item, imageCnt)));
                    break;
                default:
                    break;
            }
        }
        // Structured knowledge bases only: append the custom fields selected by the user
        // (nothing to do for unstructured knowledge bases, which carry no table fields).
        List<PointTableChunkField> tableFields = item.getTableChunkFields();
        if (tableFields != null && !tableFields.isEmpty()) {
            for (String selfField : promptExtraContextExample.SelfDefineFields) {
                for (PointTableChunkField tableChunkField : tableFields) {
                    // Objects.equals tolerates a null name on either side.
                    if (Objects.equals(selfField, tableChunkField.getFieldName())) {
                        promptBuilder.append(String.format("%s: %s\n", tableChunkField.getFieldName(), tableChunkField.getFieldValue()));
                    }
                }
            }
        }
        promptBuilder.append("---\n");
    }
    String finalPrompt = BasePrompt.replace("{prompt}", promptBuilder);
    return new PromptResult(finalPrompt, imageURLs, null);
}
/**
 * Tells whether a model is a vision model, judged solely by whether its name contains the
 * (case-sensitive) substring "vision".
 *
 * @param modelName the model name; may be {@code null}
 * @return {@code true} if {@code modelName} is non-null and contains "vision", {@code false} otherwise
 */
public static boolean isVisionModel(String modelName) {
    if (modelName == null) {
        return false;
    }
    return modelName.indexOf("vision") >= 0;
}
/**
 * Calls the chat-completion endpoint in non-streaming mode and deserializes the JSON response.
 *
 * @param messages conversation messages to send
 * @return the parsed response envelope
 * @throws Exception if signing, the HTTP call or JSON parsing fails; the stack trace is printed before rethrowing
 */
public static BaseResponse<CollectionChatCompletionResponseData> ChatCompletion(List<MessageParam> messages) throws Exception {
    ChatCompletionRequest chatCompletionRequest = GenerateChatCompletionRequest(false, messages);
    String chatCompletionRequestParamsJson = toJson(chatCompletionRequest);
    try {
        SignableRequest signableRequest = prepareRequest(HOST, CHAT_COMPLETION_PATH, "POST", null, chatCompletionRequestParamsJson, AK, SK);
        URI uri = new URIBuilder()
                .setScheme("https")
                .setHost(HOST)
                .setPath(CHAT_COMPLETION_PATH)
                .build();
        HttpPost httpPost = new HttpPost(uri);
        httpPost.setConfig(signableRequest.getConfig());
        httpPost.setEntity(signableRequest.getEntity());
        // Copy the signed headers (including the signature) onto the outgoing request.
        for (Header header : signableRequest.getAllHeaders()) {
            httpPost.setHeader(header.getName(), header.getValue());
        }
        // try-with-resources releases the connection and the client even when parsing fails.
        try (CloseableHttpClient httpClient = HttpClients.createDefault();
             CloseableHttpResponse response = httpClient.execute(httpPost)) {
            // Fall back to UTF-8 (not ISO-8859-1) when the response does not declare a charset.
            String responseBody = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
            return new ObjectMapper().readValue(responseBody, new TypeReference<BaseResponse<CollectionChatCompletionResponseData>>() {
            });
        }
    } catch (Exception e) {
        e.printStackTrace();
        throw e;
    }
}
/**
 * Calls the chat-completion endpoint in streaming mode. Every {@code data:} line of the event stream is
 * parsed as a response envelope; each generated fragment is printed as it arrives, and the full answer
 * plus the token usage (carried by the final event) are printed at the end.
 *
 * @param messages conversation messages to send
 * @throws IllegalStateException if an event carries no data payload (its code and message are included)
 * @throws Exception             if signing, the HTTP call or JSON parsing fails; the stack trace is printed before rethrowing
 */
public static void ChatCompletionStream(List<MessageParam> messages) throws Exception {
    ChatCompletionRequest chatCompletionRequest = GenerateChatCompletionRequest(true, messages);
    // Serialize the request.
    String chatCompletionRequestParamsJson = toJson(chatCompletionRequest);
    try {
        SignableRequest signableRequest = prepareRequest(HOST, CHAT_COMPLETION_PATH, "POST", null, chatCompletionRequestParamsJson, AK, SK);
        URI uri = new URIBuilder()
                .setScheme("https")
                .setHost(HOST)
                .setPath(CHAT_COMPLETION_PATH)
                .build();
        HttpPost httpPost = new HttpPost(uri);
        httpPost.setConfig(signableRequest.getConfig());
        httpPost.setEntity(signableRequest.getEntity());
        for (Header header : signableRequest.getAllHeaders()) {
            httpPost.setHeader(header.getName(), header.getValue());
        }
        httpPost.setHeader("Accept", "text/event-stream"); // ask for a server-sent event stream

        // One mapper and type token for the whole stream instead of one per line.
        ObjectMapper mapper = new ObjectMapper();
        TypeReference<BaseResponse<CollectionChatCompletionResponseData>> eventType =
                new TypeReference<BaseResponse<CollectionChatCompletionResponseData>>() {
                };
        StringBuilder answer = new StringBuilder();
        String usage = "";
        // Decode explicitly as UTF-8 so the text does not depend on the platform charset, and close
        // the reader, response and client when done.
        try (CloseableHttpClient httpClient = HttpClients.createDefault();
             CloseableHttpResponse response = httpClient.execute(httpPost);
             BufferedReader reader = new BufferedReader(
                     new InputStreamReader(response.getEntity().getContent(), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (!line.startsWith("data:") || line.length() <= 5) {
                    continue;
                }
                BaseResponse<CollectionChatCompletionResponseData> resp = mapper.readValue(line.substring(5), eventType);
                CollectionChatCompletionResponseData data = resp.getData();
                if (data == null) {
                    // Surface the service's own code/message instead of a bare NullPointerException.
                    throw new IllegalStateException(String.format(
                            "chat completion stream event has no data: code=%s, message=%s", resp.getCode(), resp.getMessage()));
                }
                // Events without generated text must not add the text "null" to the answer.
                if (data.getGenerateAnswer() != null) {
                    System.out.println(data.getGenerateAnswer());
                    answer.append(data.getGenerateAnswer());
                }
                // The last event of the stream carries the usage information.
                if (Boolean.TRUE.equals(data.getEnd())) {
                    usage = data.getUsage();
                }
            }
        }
        System.out.printf("大模型流式调用返回结果: %s\n", answer);
        System.out.printf("大模型流式调用返回token使用情况: %s\n", usage);
    } catch (Exception e) {
        e.printStackTrace();
        throw e;
    }
}
/**
 * Builds the message list for the chat-completion call.
 *
 * <p>With images, the user message is multimodal: one text part holding the query followed by one
 * image part per URL. Without images, the user message is the plain query string. In both cases the
 * system prompt comes first. For multi-turn conversations the roles must follow the order
 * {@code [system, user, assistant, user, assistant, ...]}; further turns can be appended after the
 * first user message in the same way.
 *
 * @param prompt the system prompt
 * @param query  the user's question
 * @param images image URLs for a vision model; {@code null} or empty selects the text-only form
 * @return the messages, system prompt first
 */
public static List<MessageParam> BuildMessages(String prompt, String query, List<String> images) {
    List<MessageParam> messages = new ArrayList<>();
    messages.add(new MessageParam("system", prompt)); // system prompt
    // GeneratePrompt hands back a null image list on failure, so null is treated like "no images".
    if (images != null && !images.isEmpty()) {
        // Vision model: text part for the query, then one image part per URL.
        List<ChatCompletionMessageContentPart> multiModalMessage = new ArrayList<>();
        multiModalMessage.add(new ChatCompletionMessageContentPart(ChatCompletionMessageContentPartType.TEXT, query, null));
        for (String imageURL : images) {
            multiModalMessage.add(new ChatCompletionMessageContentPart(ChatCompletionMessageContentPartType.IMAGE_URL, null, new ChatMessageImageURL(imageURL)));
        }
        messages.add(new MessageParam("user", multiModalMessage));
    } else {
        // Plain text model: the user's first question.
        messages.add(new MessageParam("user", query));
    }
    return messages;
}
/**
 * Runs the full RAG flow: retrieve from the knowledge base, build the prompt, assemble the messages and
 * call the LLM, either streaming or non-streaming. The generated answer and token usage are printed to
 * standard output in both modes.
 *
 * @param stream {@code true} for a streaming call; {@code false} or {@code null} for a non-streaming call
 * @throws RuntimeException wrapping any failure of the individual steps, including a prompt that could
 *                          not be generated from the retrieval response
 */
public static void RAG(Boolean stream) throws Exception {
    try {
        // 1. Retrieve from the knowledge base.
        BaseResponse<CollectionSearchKnowledgeResponseData> searchResponse = SearchKnowledge();
        // 2. Build the LLM prompt.
        PromptResult promptResult = GeneratePrompt(searchResponse);
        // Stop here when retrieval failed instead of calling the LLM with an empty prompt.
        if (promptResult.error != null) {
            throw promptResult.error;
        }
        // 3. Assemble the conversation messages.
        List<MessageParam> messages = BuildMessages(promptResult.prompt, Query, promptResult.imageURLs);
        // 4. Call the LLM (streaming / non-streaming).
        if (Boolean.TRUE.equals(stream)) {
            ChatCompletionStream(messages);
        } else {
            BaseResponse<CollectionChatCompletionResponseData> chatResponse = ChatCompletion(messages);
            CollectionChatCompletionResponseData data = chatResponse.getData();
            if (data == null) {
                throw new IllegalStateException(String.format(
                        "chat completion returned no data: code=%s, message=%s", chatResponse.getCode(), chatResponse.getMessage()));
            }
            // Print the result like the streaming path does; otherwise the answer would be discarded.
            System.out.printf("大模型非流式调用返回生成结果: %s\n", data.getGenerateAnswer());
            // Usage is a JSON string; deserialize it yourself if the details are needed.
            System.out.printf("大模型非流式调用返回token使用情况: %s\n", data.getUsage());
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
/**
 * Entry point. Three flows are available; enable the one that fits your needs:
 * <ol>
 *   <li>knowledge-base retrieval only: {@code SearchKnowledge()}</li>
 *   <li>RAG (retrieval + generation), non-streaming: {@code RAG(false)}</li>
 *   <li>RAG (retrieval + generation), streaming: {@code RAG(true)} - the one enabled here</li>
 * </ol>
 */
public static void main(String[] args) throws Exception {
    // Flow 3: streaming RAG.
    final boolean useStreaming = true;
    RAG(useStreaming);
}
}