Integrating the Doubao LLM knowledge base with Spring Boot (Java SDK)

Prerequisites

Individual users or enterprises that need to integrate the Doubao LLM knowledge-base capability can use this SDK, obtained from the official Volcengine release. The sample below includes both streaming and non-streaming request demos and generates LLM chat responses from knowledge-base retrieval results (RAG). I have implemented and verified this integration myself; feel free to message me with questions.

The code is as follows.

Add the Maven dependencies

```xml
        <!-- Doubao / Volcengine SDK -->
        <dependency>
            <groupId>com.volcengine</groupId>
            <artifactId>volcengine-java-sdk-ark-runtime</artifactId>
            <version>LATEST</version>
        </dependency>
        <dependency>
            <groupId>com.volcengine</groupId>
            <artifactId>volc-sdk-java</artifactId>
            <version>1.0.206</version>
        </dependency>
```

Java code

```java
package org.example;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.volcengine.auth.ISignerV4;
import com.volcengine.auth.impl.SignerV4Impl;
import com.volcengine.model.Credentials;
import com.volcengine.service.SignableRequest;
import lombok.Data;
import lombok.experimental.Accessors;
import org.apache.http.Header;
import org.apache.http.HttpResponse;
import org.apache.http.NameValuePair;
import org.apache.http.client.HttpClient;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;

import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.util.*;

public class Main {
    public static final String HOST = "api-knowledgebase.mlp.cn-beijing.volces.com"; // knowledge-base domain
    public static final String SEARCH_KNOWLEDGE_PATH = "/api/knowledge/collection/search_knowledge"; // knowledge-base search endpoint; use this one for new integrations, the other search endpoints are no longer maintained
    public static final String CHAT_COMPLETION_PATH = "/api/knowledge/chat/completions";  // LLM chat endpoint; can be chained after the search endpoint to form a RAG flow, or used on its own for generation

    public static final String AK = "your ak";
    public static final String SK = "your sk";


    static String Query = "your query"; // your question
    static String CollectionName = "your collection"; // knowledge-base name, visible in the console UI
    static String Project = "default"; // project the knowledge base belongs to, visible in the console UI
    static String ResourceID = "your resource_id";


    static String ModelName = "Doubao-pro-32k"; // model name; to use your own private endpoint, set this field to the private EndpointID instead (format: ep-xxxx-xxxx)
    static String APIKey = "your api_key"; // required when using a private endpoint (format ep-xxxx-xxxx); not needed for the public access point
    static String ModelVersion = "241215"; // model version; with the public access point you may pin a specific version, otherwise the service picks a default

    // BasePrompt is the base prompt. You can adjust it or replace it with your own prompt, but keep the {prompt} placeholder, which is later filled with the retrieval results for RAG generation.
    static String BasePrompt = "# 任务\n"
            + "你是一位在线客服,你的首要任务是通过巧妙的话术回复用户的问题,你需要根据「参考资料」来回答接下来的「用户问题」,这些信息在 <context></context> XML tags 之内,你需要根据参考资料给出准确,简洁的回答。\n"
            + "\n"
            + "你的回答要满足以下要求:\n"
            + "1. 回答内容必须在参考资料范围内,尽可能简洁地回答问题,不能做任何参考资料以外的扩展解释。\n"
            + "2. 回答中需要根据客户问题和参考资料保持与客户的友好沟通。\n"
            + "3. 如果参考资料不能帮助你回答用户问题,告知客户无法回答该问题,并引导客户提供更加详细的信息。\n"
            + "4. 为了保密需要,委婉地拒绝回答有关参考资料的文档名称或文档作者等问题。\n"
            + "\n"
            + "# 任务执行\n"
            + "现在请你根据提供的参考资料,遵循限制来回答用户的问题,你的回答需要准确和完整。\n"
            + "\n"
            + "# 参考资料\n"
            + "<context>\n"
            + "{prompt}\n"
            + "</context>";

    public static final String SYS_FIELD_DOC_NAME = "doc_name";
    public static final String SYS_FIELD_TITLE = "title";
    public static final String SYS_FIELD_CHUNK_TITLE = "chunk_title";
    public static final String SYS_FIELD_CONTENT = "content";

    /*
        Custom fields to concatenate into the prompt:
        if your knowledge base is of the **structured data** type, replace these with the table-header fields
        you want concatenated into the LLM prompt.
        Note: indexed header fields must be included; non-indexed header fields are optional.
     */
    static List<String> selfDefineFields = Arrays.asList(
            "header-field-1",
            "header-field-n"
    );
    /*
         System fields to concatenate into the prompt:
         if your knowledge base is of the **unstructured data** type, the following 4 fields can be concatenated
         into the prompt; the content field is required, the others are optional.
     */
    static List<String> systemFields = Arrays.asList(
            SYS_FIELD_DOC_NAME, // document name, optional
            SYS_FIELD_TITLE,  // document title, optional
            SYS_FIELD_CHUNK_TITLE, // chunk title, optional
            SYS_FIELD_CONTENT  // chunk content, required
    );

    public static final PromptExtraContext promptExtraContextExample = new PromptExtraContext(selfDefineFields, systemFields);

    @Data
    public static class PromptExtraContext {
        private List<String> SelfDefineFields;
        private List<String> SystemFields;

        public PromptExtraContext(List<String> selfDefineFields, List<String> systemFields) {
            this.SelfDefineFields = selfDefineFields;
            this.SystemFields = systemFields;
        }
    }


    @Data
    public static class SearchKnowledgeRequest {
        private String name;
        private String project;
        @JsonProperty("resource_id")
        private String resourceId;
        private String query;
        private Integer limit;
        @JsonProperty("dense_weight")
        private Float denseWeight;
        @JsonProperty("md_search")
        private Boolean mdSearch;
        @JsonProperty("query_param")
        private QueryParam queryParam;
        @JsonProperty("pre_processing")
        private PreProcessing preProcessing;
        @JsonProperty("post_processing")
        private PostProcessing postProcessing;
    }

    @Data
    public static class QueryParam {
        @JsonProperty("doc_filter")
        private Object docFilter;
    }

    @Data
    public static class PreProcessing {
        @JsonProperty("need_instruction")
        private Boolean needInstruction;
        private Boolean rewrite;
        private List<MessageParam> messages;
        @JsonProperty("return_token_usage")
        private Boolean returnTokenUsage;
    }

    @Data
    public static class PostProcessing {
        @JsonProperty("rerank_switch")
        private Boolean rerankSwitch;

        @JsonProperty("rerank_model")
        private String rerankModel;
        @JsonProperty("rerank_only_chunk")
        private Boolean rerankOnlyChunk;
        @JsonProperty("retrieve_count")
        private Integer retrieveCount;
        @JsonProperty("endpoint_id")
        private String endpointId;
        @JsonProperty("chunk_diffusion_count")
        private Integer chunkDiffusionCount;
        @JsonProperty("chunk_group")
        private Boolean chunkGroup;
        @JsonProperty("chunk_score_aggr_type")
        private String chunkScoreAggrType;
        @JsonProperty("chunk_extra_content")
        private Map<String, Object> chunkExtraContent;
        @JsonProperty("get_attachment_link")
        private Boolean getAttachmentLink;
    }

    @Data
    public static class MessageParam {
        private String role;
        private Object content;

        MessageParam(String role, Object content) {
            this.role = role;
            this.content = content;
        }
    }

    @Data
    public static class ChatMessageImageURL {
        private String url;

        ChatMessageImageURL(String url) {
            this.url = url;
        }

    }

    public enum ChatCompletionMessageContentPartType {
        TEXT, IMAGE_URL
    }

    @Data
    public static class ChatCompletionMessageContentPart {
        private ChatCompletionMessageContentPartType type;
        private String text;
        private ChatMessageImageURL imageURL;

        public ChatCompletionMessageContentPart(ChatCompletionMessageContentPartType type, String text, ChatMessageImageURL imageURL) {
            this.type = type;
            this.text = text;
            this.imageURL = imageURL;
        }
    }

    @Data
    public static class BaseResponse<T> {
        private Integer code;
        private String message;
        @JsonProperty("request_id")
        private String requestId;
        private T data;
    }

    @Data
    public static class CollectionSearchKnowledgeResponseData {
        @JsonProperty("collection_name")
        private String collectionName; // knowledge-base collection_name
        private Integer count; // number of chunks returned
        @JsonProperty("rewrite_query")
        private String rewriteQuery; // the rewritten query; returned when query rewriting is enabled
        @JsonProperty("token_usage")
        private TotalTokenUsage tokenUsage; // model token-usage details
        @JsonProperty("result_list")
        private List<CollectionSearchResponseItem> resultList; // returned chunk information
    }

    @Data
    public static class TotalTokenUsage {
        @JsonProperty("embedding_token_usage")
        private ModelTokenUsage embeddingUsage;
        @JsonProperty("rerank_token_usage")
        private Long rerankUsage;
        @JsonProperty("llm_token_usage")
        private ModelTokenUsage llmUsage;
        @JsonProperty("rewrite_token_usage")
        private ModelTokenUsage rewriteUsage;
    }

    /*
        Chunk information returned by retrieval, with the individual fields.
     */
    @Data
    public static class CollectionSearchResponseItem {
        private String id;
        private String content;
        @JsonProperty("md_content")
        private String mdContent;
        private Double score;
        @JsonProperty("point_id")
        private String pointId;
        @JsonProperty("origin_text")
        private String originText;
        @JsonProperty("original_question")
        private String originalQuestion;
        @JsonProperty("chunk_title")
        private String chunkTitle;
        @JsonProperty("chunk_id")
        private Integer chunkId;
        @JsonProperty("process_time")
        private Long processTime;
        @JsonProperty("rerank_score")
        private Double rerankScore;
        @JsonProperty("doc_info")
        private CollectionSearchResponseItemDocInfo docInfo;
        @JsonProperty("recall_position")
        private Integer recallPosition;
        @JsonProperty("rerank_position")
        private Integer rerankPosition;
        @JsonProperty("chunk_type")
        private String chunkType;
        @JsonProperty("chunk_source")
        private String chunkSource;
        @JsonProperty("update_time")
        private Long updateTime;
        @JsonProperty("chunk_attachment")
        private List<ChunkAttachment> chunkAttachmentList;
        @JsonProperty("table_chunk_fields")
        private List<PointTableChunkField> tableChunkFields;
        @JsonProperty("original_coordinate")
        private ChunkPositions originalCoordinate;
    }

    @Data
    public static class CollectionSearchResponseItemDocInfo {
        @JsonProperty("doc_id")
        private String docId;
        @JsonProperty("doc_name")
        private String docName;
        @JsonProperty("create_time")
        private Long createTime;
        @JsonProperty("doc_type")
        private String docType;
        @JsonProperty("doc_meta")
        private String docMeta;
        private String source;
        private String title;
    }

    @Data
    public static class ChunkAttachment {
        private String uuid;
        private String caption;
        private String type;
        private String link;
    }

    @Data
    public static class PointTableChunkField {
        @JsonProperty("field_name")
        private String fieldName;
        @JsonProperty("field_value")
        private Object fieldValue;  // Using Object to accommodate any type of field value
    }

    @Data
    public static class ChunkPositions {
        @JsonProperty("page_no")
        private List<Integer> pageNo;
        private List<List<Double>> bbox;  // A list of lists to represent bounding box coordinates
    }


    @Data
    public static class ChatCompletionRequest {
        private String model;
        @JsonProperty("model_version")
        private String modelVersion;
        private String project;
        private Boolean Stream;
        @JsonProperty("return_token_usage")
        private Boolean ReturnTokenUsage;
        @JsonProperty("api_key")
        private String APIKey;
        @JsonProperty("max_tokens")
        private Integer MaxTokens;
        private double Temperature;
        private List<MessageParam> messages;
    }


    @Data
    public static class CollectionChatCompletionResponseData {
        @JsonProperty("generated_answer")
        private String GenerateAnswer; // text generated by the model
        private String Usage; // serialized ModelTokenUsage JSON; deserialize it yourself if you need the usage details
        @JsonProperty("reasoning_content")
        private String ReasoningContent; // the model's reasoning content; only present for reasoning models
        private Boolean End; // marks whether this is the last chunk of a streaming response
    }

    @Data
    public static class ModelTokenUsage {
        @JsonProperty("prompt_tokens")
        private Integer PromptTokens;  // token count of the request text
        @JsonProperty("completion_tokens")
        private Integer CompletionTokens; // token count of the generated text; non-zero only for LLM chat models
        @JsonProperty("total_tokens")
        private Integer TotalTokens; // PromptTokens + CompletionTokens
    }


    public static class PromptResult {
        String prompt;
        List<String> imageURLs;
        Exception error;

        public PromptResult(String prompt, List<String> imageURLs, Exception error) {
            this.prompt = prompt;
            this.imageURLs = imageURLs;
            this.error = error;
        }
    }


    public static String toJson(Object obj) {
        try {
            // create an ObjectMapper instance
            ObjectMapper objectMapper = new ObjectMapper();
            // serialize the object to a JSON string
            return objectMapper.writeValueAsString(obj);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }


    public static SignableRequest prepareRequest(String host, String path, String method, List<NameValuePair> params, String body, String ak, String sk) throws Exception {
        SignableRequest request = new SignableRequest();
        request.setMethod(method);
        request.setHeader("Accept", "application/json");
        request.setHeader("Content-Type", "application/json");
        request.setHeader("Host", HOST);
        request.setEntity(new StringEntity(body, "utf-8"));

        URIBuilder builder = request.getUriBuilder();
        builder.setScheme("https");
        builder.setHost(host);
        builder.setPath(path);
        if (params != null) {
            builder.setParameters(params);
        }

        RequestConfig requestConfig = RequestConfig.custom().setSocketTimeout(120000).setConnectTimeout(12000).build();
        request.setConfig(requestConfig);

        Credentials credentials = new Credentials("cn-north-1", "air");
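        // region ("cn-north-1") and service name ("air") above form the V4 signing scope this demo uses; the AK/SK pair is set below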
        credentials.setAccessKeyID(ak);
        credentials.setSecretAccessKey(sk);

        // sign the request with the V4 signer
        ISignerV4 ISigner = new SignerV4Impl();
        ISigner.sign(request, credentials);

        return request;
    }


    /**
     * Builds the knowledge-base search request. The rules for some parameters are shown in detail below; for the
     * remaining parameters see the official API documentation. For a quick first integration you can pass only the
     * minimal parameter set and rely on defaults for everything else.
     * Required parameters:
     * 1. resourceId (or, equivalently, name + project; choose one of the two)
     * 2. query: the user question
     */
    public static SearchKnowledgeRequest GenerateSearchKnowledgeRequest() {
        SearchKnowledgeRequest requestParams = new SearchKnowledgeRequest();
        requestParams.setQuery(Query);
        requestParams.setName(CollectionName); // optional
        requestParams.setProject(Project); // optional
        //requestParams.setResourceId(ResourceID); // knowledge-base resource_id (either resource_id or name + project can be used for the query)
        requestParams.setLimit(10); // number of results ultimately returned by retrieval; defaults to 10 if omitted
        requestParams.setDenseWeight(0.5f); // dense (semantic) retrieval weight used for hybrid retrieval

        /*
            Retrieval pre-processing settings, providing optional extras such as query rewriting and intent
            recognition; enable as needed, not required.
         */
        PreProcessing preprocessing = new PreProcessing();
        preprocessing.setNeedInstruction(true); // whether the embedding model prepends an instruction; see the official docs
        preprocessing.setReturnTokenUsage(true); // whether retrieval returns token-usage information
        preprocessing.setRewrite(false); // query-rewrite switch; rewrites the question based on multi-turn history, off by default

        requestParams.setPreProcessing(preprocessing);

        /*
            Retrieval post-processing settings, providing optional extras such as rerank; enable as needed, not required.
         */
        PostProcessing postProcessing = new PostProcessing();
        postProcessing.setRerankSwitch(false); // rerank switch, off by default
        postProcessing.setRetrieveCount(25); // number of chunks fed into rerank; takes effect when rerank is on and must be greater than limit (default 25 when limit = 10), adjust as needed
        postProcessing.setChunkGroup(true); // whether recalled chunks are grouped by document
        requestParams.setPostProcessing(postProcessing); // attach the post-processing settings to the request

        return requestParams;
    }

    // Call the knowledge-base search endpoint
    public static BaseResponse<CollectionSearchKnowledgeResponseData> SearchKnowledge() throws Exception {
        SearchKnowledgeRequest searchRequestParams = GenerateSearchKnowledgeRequest();
        String searchRequestParamsJson = toJson(searchRequestParams);
        // you can add more detailed error handling here as needed
        try {
            SignableRequest signableRequest = prepareRequest(HOST, SEARCH_KNOWLEDGE_PATH, "POST", null, searchRequestParamsJson, AK, SK);
            URI uri = new URIBuilder()
                    .setScheme("https")
                    .setHost(HOST)
                    .setPath(SEARCH_KNOWLEDGE_PATH)
                    .build();

            HttpPost httpPost = new HttpPost(uri);
            httpPost.setConfig(signableRequest.getConfig());
            httpPost.setEntity(signableRequest.getEntity());
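            // copy the signed headers (including the computed signature) onto the actual outgoing request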
            for (Header header : signableRequest.getAllHeaders()) {
                httpPost.setHeader(header.getName(), header.getValue());
            }

            HttpClient httpClient = HttpClients.createDefault();

            HttpResponse response = httpClient.execute(httpPost);

            int statusCode = response.getStatusLine().getStatusCode();
            String responseBody = EntityUtils.toString(response.getEntity());

            BaseResponse<CollectionSearchKnowledgeResponseData> resp = new ObjectMapper().readValue(responseBody, new TypeReference<BaseResponse<CollectionSearchKnowledgeResponseData>>() {
            });
            return resp;
        } catch (Exception e) {
            e.printStackTrace();
            throw e;
        }
    }


    public static ChatCompletionRequest GenerateChatCompletionRequest(Boolean Stream, List<MessageParam> messages) {
        ChatCompletionRequest requestParams = new ChatCompletionRequest();
        requestParams.setModel(ModelName); // passing a model name means using the public access point; to use a private endpoint, put the endpoint ID here instead (format ep-xxx-xxx)
        requestParams.setModelVersion(ModelVersion); // model version; with the public access point you may pin a version, otherwise the service picks a default
        requestParams.setReturnTokenUsage(true); // whether to return the LLM call's token usage
        requestParams.setAPIKey(APIKey); // required when using a private endpoint (i.e. model is ep-xxx-xxx), otherwise ignored
        requestParams.setMaxTokens(4096); // maximum number of output tokens, currently capped at 4096
        requestParams.setTemperature(0.7f); // model temperature in [0, 1]; larger values mean more randomness
        requestParams.setStream(Stream); // whether the model result is streamed back
        requestParams.setMessages(messages); // chat messages; make sure the format and turn order are correct, see the official docs

        return requestParams;
    }


    public static String GetContentForPrompt(CollectionSearchResponseItem item, int imageNum) {
        String content = item.content;

        if (item.originalQuestion != null && !item.originalQuestion.isEmpty()) {
            return String.format("当询问到相似问题时,请参考对应答案进行回答:问题:“%s”。答案:“%s”", item.getOriginalQuestion(), content);
        }
        if (imageNum > 0 && item.getChunkAttachmentList() != null && !item.getChunkAttachmentList().isEmpty()
                && item.getChunkAttachmentList().get(0).getLink() != null && !item.getChunkAttachmentList().get(0).getLink().isEmpty()) {
            String placeholder = String.format("<img>图片%d</img>", imageNum);
            return content + placeholder;
        }
        return content;
    }


    // Build the system prompt for the LLM call by combining the retrieval results with the BasePrompt template
    public static PromptResult GeneratePrompt(BaseResponse<CollectionSearchKnowledgeResponseData> resp) {
        if (resp == null) {
            return new PromptResult("", null, new IllegalArgumentException("response is nil"));
        }
        if (resp.getCode() != 0) {
            return new PromptResult("", null, new IllegalArgumentException(resp.getMessage()));
        }

        StringBuilder promptBuilder = new StringBuilder();
        List<String> imageURLs = new ArrayList<>();
        boolean usingVLM = isVisionModel("ModelName");
        int imageCnt = 0;

        for (CollectionSearchResponseItem item : resp.data.resultList) {
            if (usingVLM && item.getChunkAttachmentList() != null && !item.getChunkAttachmentList().isEmpty()) {
                String link = item.getChunkAttachmentList().get(0).getLink();
                if (link != null && !link.isEmpty()) {
                    imageURLs.add(link);
                    imageCnt++;
                }
            }

            CollectionSearchResponseItemDocInfo docInfo = item.getDocInfo();

            // concatenate the user-specified system fields
            for (String sysField : promptExtraContextExample.SystemFields) {
                switch (sysField) {
                    case SYS_FIELD_DOC_NAME:
                        promptBuilder.append(String.format("%s: %s\n", sysField, docInfo.getDocName()));
                        break;
                    case SYS_FIELD_TITLE:
                        promptBuilder.append(String.format("%s: %s\n", sysField, docInfo.getTitle()));
                        break;
                    case SYS_FIELD_CHUNK_TITLE:
                        promptBuilder.append(String.format("%s: %s\n", sysField, item.getChunkTitle()));
                        break;
                    case SYS_FIELD_CONTENT:
                        promptBuilder.append(String.format("%s: %s\n", sysField, GetContentForPrompt(item, imageCnt)));
                        break;
                }
            }

            // Note: for structured knowledge bases, concatenate the user-specified custom fields into the prompt (not needed for unstructured knowledge bases)
            for (String selfField : promptExtraContextExample.SelfDefineFields) {
                if (item.getTableChunkFields() != null && !item.getTableChunkFields().isEmpty()) {
                    for (PointTableChunkField tableChunkField : item.getTableChunkFields()) {
                        if (tableChunkField.getFieldName().equals(selfField)) {
                            promptBuilder.append(String.format("%s: %s\n", tableChunkField.getFieldName(), tableChunkField.getFieldValue()));
                        }
                    }
                }
            }

            promptBuilder.append("---\n");
        }

        String finalPrompt = BasePrompt.replace("{prompt}", promptBuilder);
        //System.out.println(finalPrompt);
        return new PromptResult(finalPrompt, imageURLs, null);
    }


    /**
     * Checks whether the model name contains the string "vision" to decide whether it is a vision model.
     *
     * @param modelName model name
     * @return true if the model name contains "vision", false otherwise.
     */
    public static boolean isVisionModel(String modelName) {
        return modelName != null && modelName.contains("vision");
    }

    public static BaseResponse<CollectionChatCompletionResponseData> ChatCompletion(List<MessageParam> messages) throws Exception {
        ChatCompletionRequest chatCompletionRequest = GenerateChatCompletionRequest(false, messages);
        String chatCompletionRequestParamsJson = toJson(chatCompletionRequest);
        try {
            SignableRequest signableRequest = prepareRequest(HOST, CHAT_COMPLETION_PATH, "POST", null, chatCompletionRequestParamsJson, AK, SK);
            URI uri = new URIBuilder()
                    .setScheme("https")
                    .setHost(HOST)
                    .setPath(CHAT_COMPLETION_PATH)
                    .build();

            HttpPost httpPost = new HttpPost(uri);
            httpPost.setConfig(signableRequest.getConfig());
            httpPost.setEntity(signableRequest.getEntity());
            for (Header header : signableRequest.getAllHeaders()) {
                httpPost.setHeader(header.getName(), header.getValue());
            }

            HttpClient httpClient = HttpClients.createDefault();
            HttpResponse response = httpClient.execute(httpPost);

            int statusCode = response.getStatusLine().getStatusCode();
            String responseBody = EntityUtils.toString(response.getEntity());

            BaseResponse<CollectionChatCompletionResponseData> resp = new ObjectMapper().readValue(responseBody, new TypeReference<BaseResponse<CollectionChatCompletionResponseData>>() {
            });

            return resp;
        } catch (Exception e) {
            e.printStackTrace();
            throw e;
        }
    }


    /**
     * Streaming Chat-Completion call: streams back the generated content along with all other information.
     */
    public static void ChatCompletionStream(List<MessageParam> messages) throws Exception {
        ChatCompletionRequest chatCompletionRequest = GenerateChatCompletionRequest(true, messages);
        // serialize the request body
        String chatCompletionRequestParamsJson = toJson(chatCompletionRequest);
        try {
            SignableRequest signableRequest = prepareRequest(HOST, CHAT_COMPLETION_PATH, "POST", null, chatCompletionRequestParamsJson, AK, SK);
            URI uri = new URIBuilder()
                    .setScheme("https")
                    .setHost(HOST)
                    .setPath(CHAT_COMPLETION_PATH)
                    .build();

            HttpPost httpPost = new HttpPost(uri);
            httpPost.setConfig(signableRequest.getConfig());
            httpPost.setEntity(signableRequest.getEntity());
            for (Header header : signableRequest.getAllHeaders()) {
                httpPost.setHeader(header.getName(), header.getValue());
            }
            httpPost.setHeader("Accept", "text/event-stream"); // 设置Accept头为text/event-stream

            HttpClient httpClient = HttpClients.createDefault();
            HttpResponse response = httpClient.execute(httpPost);

            InputStream stream = response.getEntity().getContent();
            BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
            StringBuilder answer = new StringBuilder();
            String usage = "";
            String line;
            while ((line = reader.readLine()) != null) {
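                // each SSE event is a line prefixed with "data:"; the JSON payload follows the prefix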
                if (line.startsWith("data:") && line.length() > 5) {
                    String content = line.substring(5);
                    BaseResponse<CollectionChatCompletionResponseData> resp = new ObjectMapper().readValue(content, new TypeReference<BaseResponse<CollectionChatCompletionResponseData>>() {
                    });
                    System.out.println(resp.data.GenerateAnswer);
                    answer.append(resp.data.GenerateAnswer);
                    // the last stream event contains the usage information
                    if (resp.data.getEnd() != null && resp.data.End) {
                        usage = resp.data.Usage;
                        //System.out.println(resp.data.Usage);
                    }
                }
            }

            System.out.printf("大模型流式调用返回结果: %s\n", answer);
            System.out.printf("大模型流式调用返回token使用情况: %s\n", usage);
        } catch (Exception e) {
            e.printStackTrace();
            throw e;
        }
    }

    public static List<MessageParam> BuildMessages(String prompt, String query, List<String> images) {
        List<MessageParam> messages = new ArrayList<>();

        if (!images.isEmpty()) {
            /*
                If a VLM (vision) model is used, this branch builds the messages. The role order of the
                assembled messages is:
                 [system, user, assistant, user, assistant, user, assistant, user...]
             */
            List<ChatCompletionMessageContentPart> multiModalMessage = new ArrayList<>();
            multiModalMessage.add(new ChatCompletionMessageContentPart(ChatCompletionMessageContentPartType.TEXT, query, null));

            for (String imageURL : images) {
                multiModalMessage.add(new ChatCompletionMessageContentPart(ChatCompletionMessageContentPartType.IMAGE_URL, null, new ChatMessageImageURL(imageURL)));
            }

            messages.add(new MessageParam("system", prompt));
            messages.add(new MessageParam("user", multiModalMessage));
        } else {
            /*
                If a plain text LLM is used, this branch builds the messages. The role order of the
                assembled multi-turn messages is:
                 [system, user, assistant, user, assistant, user, assistant, user...]
             */
            messages.add(new MessageParam("system", prompt)); // system prompt
            messages.add(new MessageParam("user", query)); // first user question
            //messages.add(new MessageParam("assistant", new ChatCompletionMessageContent(query, null))); // LLM answer
            //messages.add(new MessageParam("user", new ChatCompletionMessageContent(query, null))); // second-round user question
            //messages.add(new MessageParam("assistant", new ChatCompletionMessageContent(query, null))); // second-round LLM answer
            //...
        }
        //System.out.println("messages: " + messages);
        return messages;
    }


    public static void RAG(Boolean stream) throws Exception {
        try {
            // 1. First run the knowledge-base search
            BaseResponse<CollectionSearchKnowledgeResponseData> searchResponse = SearchKnowledge();

            // 2. Generate the LLM prompt
            PromptResult promptResult = GeneratePrompt(searchResponse);

            // 3. Assemble the chat messages
            List<MessageParam> messages = BuildMessages(promptResult.prompt, Query, promptResult.imageURLs);

            // 4. Call the LLM for generation (streaming / non-streaming)
            if (stream) {
                // streaming call
                ChatCompletionStream(messages);
            } else {
                // non-streaming call
                BaseResponse<CollectionChatCompletionResponseData> chatResponse = ChatCompletion(messages);
                String generatedAnswer = chatResponse.data.GenerateAnswer;
                //System.out.printf("Non-streaming LLM call result: %s\n", generatedAnswer);
                String usage = chatResponse.data.Usage; // a JSON string; deserialize it yourself if needed
                //System.out.printf("Non-streaming LLM call token usage: %s\n", usage);
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    // Three flows; use whichever fits your needs
    public static void main(String[] args) throws Exception {
        // 1. Knowledge-base search only
        //BaseResponse<CollectionSearchKnowledgeResponseData> searchResponse = SearchKnowledge();
        // 2. RAG search + generation flow, non-streaming
        //RAG(false);
        // 3. RAG search + generation flow, streaming
        RAG(true);
    }
}
```
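
Since this article targets Spring Boot, the following is a minimal sketch of how the demo's RAG flow could be exposed through a Spring Boot endpoint. It is only an illustration, not part of the official SDK: it assumes `spring-boot-starter-web` is on the classpath, that the classes live in the same package (`org.example`) as the `Main` class above, and that the names `KnowledgeBaseRagService`, `KnowledgeBaseRagController` and the `/api/doubao/rag` path are made up for the example.

```java
package org.example;

import java.util.List;

import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

@Service
class KnowledgeBaseRagService {

    /** Runs search -> prompt assembly -> non-streaming chat completion and returns the generated answer. */
    public String answer(String question) throws Exception {
        // NOTE: the demo reads the question from the static Query field inside SearchKnowledge();
        // in a real service you would refactor that method to accept the question as a parameter.
        Main.BaseResponse<Main.CollectionSearchKnowledgeResponseData> searchResponse = Main.SearchKnowledge();

        Main.PromptResult promptResult = Main.GeneratePrompt(searchResponse);
        if (promptResult.error != null) {
            throw promptResult.error;
        }

        List<Main.MessageParam> messages = Main.BuildMessages(promptResult.prompt, question, promptResult.imageURLs);
        Main.BaseResponse<Main.CollectionChatCompletionResponseData> chatResponse = Main.ChatCompletion(messages);
        return chatResponse.getData().getGenerateAnswer();
    }
}

@RestController
@RequestMapping("/api/doubao")
class KnowledgeBaseRagController {

    private final KnowledgeBaseRagService ragService;

    KnowledgeBaseRagController(KnowledgeBaseRagService ragService) {
        this.ragService = ragService;
    }

    // POST a plain-text question, get the generated answer back
    @PostMapping("/rag")
    public ResponseEntity<String> rag(@RequestBody String question) throws Exception {
        return ResponseEntity.ok(ragService.answer(question));
    }
}
```

The streaming variant could be wired up the same way by adapting `ChatCompletionStream` to push chunks through Spring's `SseEmitter` instead of printing them.
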
### Spring Boot integration with the Doubao LLM

Integrating machine-learning models into web applications is a common requirement in modern application development, and a Spring Boot project can achieve this in several ways. The approach below completes the Spring Boot / Doubao integration through a REST API call.

#### Create a controller to handle requests

To let the frontend conveniently send data to the backend and receive the prediction result, define a new controller in the Spring Boot project that accepts HTTP POST requests:

```java
@RestController
@RequestMapping("/api/v1/doubao")
public class DoubaoModelController {

    private final PredictionService predictionService;

    public DoubaoModelController(PredictionService predictionService) {
        this.predictionService = predictionService;
    }

    @PostMapping("/predict")
    public ResponseEntity<PredictionResponse> predict(@RequestBody PredictionRequest request) {
        // delegate to the service layer to obtain the prediction result
        PredictionResponse response = predictionService.predict(request);
        return ResponseEntity.ok(response);
    }
}
```

Here `PredictionRequest` and `PredictionResponse` are assumed to be the input parameters and the returned result object, respectively.

#### Implement the service interface for the business logic

Next, create an interface named `PredictionService` and its implementation `PredictionServiceImpl`, which is responsible for actually calling the external API (or a locally deployed Doubao model) to run inference:

```java
public interface PredictionService {
    PredictionResponse predict(PredictionRequest request);
}

@Service
class PredictionServiceImpl implements PredictionService {

    private static final String DOUBAO_API_URL = "https://example.com/api/predict";

    @Override
    public PredictionResponse predict(PredictionRequest request) {
        RestTemplate restTemplate = new RestTemplate();
        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_JSON);

        // RestTemplate serializes the request object to JSON via its message converters
        HttpEntity<PredictionRequest> entity = new HttpEntity<>(request, headers);
        ResponseEntity<PredictionResponse> response = restTemplate.exchange(DOUBAO_API_URL,
                HttpMethod.POST, entity, PredictionResponse.class);

        return response.getBody();
    }
}
```

The snippet above shows how to use `RestTemplate` to issue a POST request to the external API server (`DOUBAO_API_URL`), passing the model input as JSON and parsing the response body into a custom object for later use.

Note that this is only a simple example; a real integration may involve more elaborate configuration, error handling, and so on.