- maven引入依赖
<dependency>
<groupId>com.volcengine</groupId>
<artifactId>volcengine-java-sdk-ark-runtime</artifactId>
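<!-- Pin an explicit version here: the LATEST meta-version makes builds non-reproducible and is discouraged since Maven 3. -->
<version>LATEST</version>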
</dependency>
- 代码
AI实现类
package com.hpp.common.AIChat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import com.google.gson.Gson;
import com.hpp.common.redisson.RedisCache;
import com.volcengine.ark.runtime.model.completion.chat.ChatCompletionRequest;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessage;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessageRole;
import com.volcengine.ark.runtime.service.ArkService;
/**
 * Streams chat answers from the Volcengine Ark (Doubao) chat-completion API to a client over
 * Server-Sent Events.
 */
@Component
public class AIChatDouBao {
    private static final Logger LOGGER = LoggerFactory.getLogger(AIChatDouBao.class);

    /** Inference endpoint ID passed as the model name; replace with your own endpoint ID. */
    private static final String MODEL_ENDPOINT_ID = "ep-xxxxxxxxxxxx";

    @Autowired
    private RedisCache redisCache;

    private Gson gson = new Gson();

    // NOTE(review): do not commit a real API key; load it from configuration or the environment.
    private String apiKey = "xxxxxxxxxxxxxxxx";

    // One client for the whole component instead of a new one per request (the previous code built
    // an ArkService on every call and never shut it down). Must stay declared after apiKey because
    // field initializers run in declaration order.
    // NOTE(review): assumes ArkService is safe to share across concurrent requests - confirm.
    private final ArkService arkService = ArkService.builder().apiKey(apiKey).build();

    // Worker threads that consume the blocking stream off the servlet request thread.
    // NOTE(review): never shut down on context close; consider a @PreDestroy hook.
    private final ExecutorService streamExecutor = Executors.newCachedThreadPool();

    /**
     * Sends {@code question} as a single user message and streams the model's reply.
     *
     * <p>Reasoning text (when the model produces it) and answer text are each sent as SSE events in
     * arrival order. The emitter is completed when the stream ends, or completed with an error if
     * the API call or a send fails.
     *
     * @param question the user's question
     * @return an emitter that is populated asynchronously by a worker thread
     */
    public SseEmitter conversation(String question) {
        ChatCompletionRequest request = buildRequest(question);
        SseEmitter emitter = new SseEmitter();
        try {
            streamExecutor.execute(() -> streamToEmitter(request, emitter));
        } catch (RuntimeException e) {
            // The task could not be scheduled (e.g. the executor rejected it).
            LOGGER.error("ai回复异常", e);
            emitter.completeWithError(e);
        }
        return emitter;
    }

    /** Builds a single-turn chat-completion request containing only the user's question. */
    private static ChatCompletionRequest buildRequest(String question) {
        ChatMessage userMessage = ChatMessage.builder()
                .role(ChatMessageRole.USER)
                .content(question)
                .build();
        List<ChatMessage> chatMessages = new ArrayList<>();
        chatMessages.add(userMessage);
        return ChatCompletionRequest.builder()
                .model(MODEL_ENDPOINT_ID)
                .messages(chatMessages)
                .build();
    }

    /**
     * Consumes the streamed completion on the current (worker) thread and forwards each chunk.
     *
     * <p>Any failure, whether from the API or from a send to a disconnected client, is logged and
     * reported through {@code completeWithError}. It is no longer left to escape as an uncaught
     * thread exception while the client sees a normal completion.
     */
    private void streamToEmitter(ChatCompletionRequest request, SseEmitter emitter) {
        try {
            arkService.streamChatCompletion(request).blockingForEach(chunk -> {
                if (chunk.getChoices() == null || chunk.getChoices().isEmpty()) {
                    return;
                }
                ChatMessage message = chunk.getChoices().get(0).getMessage();
                if (message == null) {
                    return;
                }
                // Deep-reasoning output (chain of thought), if triggered, is forwarded first.
                sendIfNotEmpty(emitter, message.getReasoningContent());
                // Content may be absent on some chunks; guard instead of calling toString() on null.
                Object content = message.getContent();
                sendIfNotEmpty(emitter, content == null ? null : content.toString());
            });
            LOGGER.info("完成");
            emitter.complete();
        } catch (Exception e) {
            LOGGER.error("ai回复异常", e);
            emitter.completeWithError(e);
        }
    }

    /** Sends {@code text} as one SSE event unless it is null or empty. */
    private static void sendIfNotEmpty(SseEmitter emitter, String text) throws IOException {
        if (text == null || text.isEmpty()) {
            return;
        }
        // Per-token output is logged at DEBUG to avoid flooding INFO logs.
        LOGGER.debug("{}", text);
        emitter.send(text);
    }

    /** Placeholder; currently always returns an empty string. */
    public String getAccessToken() throws IOException {
        return "";
    }
}
controller层
/**
 * HTTP entry point for the streaming AI chat.
 */
@RestController
@RequestMapping("chat")
public class ChatController {
    private static final Logger LOGGER = LoggerFactory.getLogger(ChatController.class);

    @Autowired
    private AIChatDouBao aiChatDouBao;

    /**
     * Streams the answer to {@code question} as Server-Sent Events.
     *
     * @param req the current HTTP request (not used by this handler)
     * @param question the question, bound from the required {@code question} request parameter
     * @return the emitter produced by {@link AIChatDouBao#conversation(String)}
     */
    @GetMapping(value = "/event-stream")
    public SseEmitter streamEvents(HttpServletRequest req, @RequestParam String question) {
        SseEmitter eventStream = aiChatDouBao.conversation(question);
        return eventStream;
    }
}