package com.msxf.pai.agent.workflow.application.listeners;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.util.IdUtil;
import cn.hutool.core.util.ObjectUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.ObjectUtils;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.google.gson.*;
import com.msxf.eyas.thread.MsxfRunnable;
import com.msxf.pai.agent.agents.application.service.autoagent.AutoAgentChatItemDetailServiceImpl;
import com.msxf.pai.agent.agents.application.service.autoagent.AutoAgentChatItemServiceImpl;
import com.msxf.pai.agent.agents.application.service.autoagent.AutoAgentChatServiceImpl;
import com.msxf.pai.agent.agents.domain.po.AutoAgentChat;
import com.msxf.pai.agent.agents.domain.po.AutoAgentChatItem;
import com.msxf.pai.agent.agents.domain.po.AutoAgentChatItemDetail;
import com.msxf.pai.agent.common.entity.constant.MessageKeyConstants;
import com.msxf.pai.agent.common.thread.ExecutorUtil;
import com.msxf.pai.agent.common.utils.MessageUtils;
import com.msxf.pai.agent.common.utils.SpringUtils;
import com.msxf.pai.agent.common.utils.UuidUtil;
import com.msxf.pai.agent.serving.application.util.SSEUtils;
import com.msxf.pai.agent.serving.domain.enums.SseResponseEventEnum;
import com.msxf.pai.agent.workflow.application.client.dto.autoagent.*;
import com.msxf.pai.agent.workflow.domain.enums.AutoAgentChatTypeEnum;
import com.msxf.pai.common.domain.dto.SessionUserInfo;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.jetbrains.annotations.NotNull;
import javax.annotation.Nullable;
import java.util.*;
import java.util.concurrent.*;
/**
 * OkHttp SSE listener for a single auto-agent chat request.
 *
 * <p>Each non-blank upstream event is parsed with Gson into an {@link AutoAgentRagVO}, recorded
 * together with its raw text, and, in streaming mode, forwarded to the client through
 * {@link SSEUtils}. In non-streaming mode {@code latch} is released when the stream ends so a
 * waiting caller can read the collected results via {@link #getRagVO(String)}. When the stream
 * closes normally and the request is not in draft mode, the conversation is persisted
 * asynchronously on {@code ExecutorUtil.antoAgentChatExecutor}.
 */
@Slf4j
public class AutoAgentChatSSEListener extends EventSourceListener {

    /** Chat titles longer than this many chars are truncated and suffixed with "...". */
    private static final int CHAT_TITLE_MAX_LENGTH = 30;

    /**
     * Single shared (thread-safe) Gson instance used both to parse upstream payloads and to
     * serialize SSE payloads. HTML escaping is disabled so characters such as {@code < > & =}
     * in model output are emitted verbatim instead of as unicode escapes.
     * NOTE(review): Gson serializes by field name rather than via getters; confirm the VO classes
     * carry no fastjson-only annotations (e.g. {@code @JSONField}) or Date fields whose wire
     * format SSE consumers depend on.
     */
    private static final Gson GSON = new GsonBuilder().disableHtmlEscaping().create();

    private final SessionUserInfo userInfo;
    private final AutoAgentData dto;
    /** Draft-mode flag of the request; the conversation is only persisted when this is not TRUE. */
    private final Boolean draftMode;
    private final String reqId;
    private final CountDownLatch latch;
    /** Starts at 1; set to 0 once the upstream reports an error or the connection fails. */
    @Getter
    private Integer runStatus = 1;
    private final Map<String, List<AutoAgentRagVO>> ragMapLists = new ConcurrentHashMap<>();
    private final List<AutoAgentRagVO> ragVOLists = new CopyOnWriteArrayList<>();
    /** Raw payloads of non-blank events, in arrival order. */
    private final List<String> events = new CopyOnWriteArrayList<>();
    /** Arrival time of the previous event (construction time before the first one). */
    private long eventStartTime = System.currentTimeMillis();
    /** Sum of all per-event intervals, in milliseconds. */
    private long totalTime;
    private final ScheduledExecutorService scheduler;
    /** Id of the AI reply item; stamped on every VO built from an upstream event. */
    private final String chatItemId = UuidUtil.getUUID();

    /**
     * @param userInfo  user on whose behalf the chat runs (used for audit fields when persisting)
     * @param dto       request data; supplies request id, draft flag, stream flag and chat payload
     * @param latch     counted down when the stream ends in non-streaming mode
     * @param scheduler shut down when the stream closes or fails
     */
    public AutoAgentChatSSEListener(SessionUserInfo userInfo, AutoAgentData dto, CountDownLatch latch,
                                    ScheduledExecutorService scheduler) {
        this.userInfo = userInfo;
        this.dto = dto;
        this.reqId = dto.getAgentData().getRequestId();
        this.draftMode = dto.getAgentData().getDraftMode();
        this.scheduler = scheduler;
        this.latch = latch;
    }

    @Override
    public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) {
        super.onOpen(eventSource, response);
        log.info("connect to rag success:{}", response);
    }

    /**
     * Parses one upstream event, records it and (in streaming mode) forwards it to the client.
     * Blank payloads are ignored. A payload carrying {@code error_code}/{@code error_message}
     * marks the run as failed.
     */
    @Override
    public void onEvent(@NotNull EventSource eventSource, String id, String type, @NotNull String data) {
        log.info("auto agent chat rag send requestId:{}, id:{}, type:{}, data:{}", reqId, id, type, data);
        if (StringUtils.isBlank(data)) {
            return;
        }
        events.add(data);
        AutoAgentRagVO agentRagVO = new AutoAgentRagVO();
        agentRagVO.setTime(elapsedSinceLastEvent());
        agentRagVO.setChatItemId(chatItemId);

        JsonObject jsonObject = JsonParser.parseString(data).getAsJsonObject();
        readErrorFields(jsonObject, agentRagVO);
        if (ObjectUtils.isNotEmpty(agentRagVO.getError_message()) || ObjectUtils.isNotEmpty(agentRagVO.getError_code())) {
            handleUpstreamError(agentRagVO);
        } else {
            handleNormalEvent(jsonObject, agentRagVO);
        }
        recordResult(agentRagVO);
    }

    @Override
    public void onClosed(@NotNull EventSource eventSource) {
        scheduler.shutdown();
        if (isStream()) {
            SSEUtils.complete(reqId);
        } else {
            latch.countDown();
        }
        log.info("****** : sse close : *******");
        List<AutoAgentRagVO> ragVO = getRagVO(reqId);
        // A null draft flag is treated as "not draft", i.e. the conversation is persisted.
        if (CollectionUtils.isEmpty(ragVO) || Boolean.TRUE.equals(draftMode)) {
            return;
        }
        CompletableFuture.runAsync(new MsxfRunnable(() -> {
            try {
                asyncSaveChat(ragVO);
            } catch (Exception e) {
                log.error("save autoagent chat data fail ", e);
            }
        }), ExecutorUtil.antoAgentChatExecutor);
    }

    /**
     * Marks the run as failed. In non-streaming mode an error VO (code "600") is recorded, the
     * latch is released and a RuntimeException is thrown; in streaming mode an error event is
     * pushed, the client stream is completed and the collected results are discarded.
     */
    @Override
    public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t,
                          @Nullable Response response) {
        scheduler.shutdown();
        runStatus = 0;
        String detail = t != null ? t.getMessage() : response != null ? response.toString() : "";
        String msg = MessageUtils.getMessage(MessageKeyConstants.MESSAGE_API_APP_SERVING_DSLSSELISTENER_QINGQIU_DUIHUA_JIEKOU_YICHANG_YICHANG_XINXI) + detail;
        // Trailing Throwable argument makes SLF4J log the stack trace as well.
        log.error("auto agent chat sse failure requestId:{}, response : {}", reqId, response, t);
        JsonObject out = new JsonObject();
        out.addProperty("message", msg);
        if (!isStream()) {
            AutoAgentRagVO agentRagVO = new AutoAgentRagVO();
            agentRagVO.setError_message(msg);
            agentRagVO.setError_code("600");
            recordResult(agentRagVO);
            latch.countDown();
            // NOTE(review): this is thrown on the OkHttp callback thread, not to the latch-waiting
            // caller; that caller sees the failure through the error VO recorded above.
            throw new RuntimeException(out.toString(), t);
        }
        SSEUtils.pubMsg(reqId, SseResponseEventEnum.ERROR.getValue(), out.toString());
        SSEUtils.complete(reqId);
        ragVOLists.clear();
        ragMapLists.clear();
    }

    /**
     * Returns the VOs collected for {@code reqId}, or an immutable empty list when there are none.
     */
    public List<AutoAgentRagVO> getRagVO(String reqId) {
        List<AutoAgentRagVO> result = ragMapLists.get(reqId);
        return CollectionUtils.isEmpty(result) ? Collections.emptyList() : result;
    }

    // ------------------------------------------------------------------ event handling

    /** Returns ms since the previous event (or construction) and advances the running totals. */
    private long elapsedSinceLastEvent() {
        long now = System.currentTimeMillis();
        long elapsed = now - eventStartTime;
        eventStartTime = now;
        totalTime += elapsed;
        return elapsed;
    }

    /** Whether results are pushed to the client as SSE; a null flag counts as non-streaming. */
    private boolean isStream() {
        return Boolean.TRUE.equals(dto.getStream());
    }

    /** Appends a VO to the collected results, keyed by this request's id. */
    private void recordResult(AutoAgentRagVO agentRagVO) {
        ragVOLists.add(agentRagVO);
        ragMapLists.put(reqId, ragVOLists);
    }

    /** Handles an event whose top-level error fields are set: fails the run and notifies the client. */
    private void handleUpstreamError(AutoAgentRagVO agentRagVO) {
        runStatus = 0;
        String errorJson = GSON.toJson(agentRagVO.getError_message());
        if (isStream()) {
            SSEUtils.pubMsg(reqId, AutoAgentChatTypeEnum.ERROR.getValue(), errorJson);
        }
        log.warn("error_message : {}", errorJson);
    }

    /** Fills metadata, timestamp, content parts and id of a regular event, then publishes it if streaming. */
    private void handleNormalEvent(JsonObject jsonObject, AutoAgentRagVO agentRagVO) {
        JsonObject metadataJson = getJsonObject(jsonObject, AutoAgentChatTypeEnum.CUSTOM_METADATA.getValue());
        CustomMetadata customMetadata = metadataJson == null ? null : GSON.fromJson(metadataJson, CustomMetadata.class);
        agentRagVO.setCustom_metadata(customMetadata);

        JsonElement timestamp = getMember(jsonObject, AutoAgentChatTypeEnum.TIMESTAMP.getValue());
        if (timestamp != null) {
            agentRagVO.setTimestamp(timestamp.getAsLong());
        }
        writeRagVo(jsonObject, agentRagVO);
        // Prefer the upstream-provided id; fall back to a fresh UUID when metadata or id is missing.
        agentRagVO.setId(customMetadata == null || ObjectUtils.isEmpty(customMetadata.getId())
                ? UuidUtil.getUUID() : customMetadata.getId());

        if (isStream()) {
            publish(customMetadata, agentRagVO);
        }
    }

    /**
     * Pushes the VO to the client: as TOOL_ERROR if any function response reports an error,
     * otherwise under the event type derived from its metadata.
     */
    private void publish(CustomMetadata customMetadata, AutoAgentRagVO agentRagVO) {
        String payload = GSON.toJson(agentRagVO);
        if (hasToolError(agentRagVO)) {
            SSEUtils.pubMsg(reqId, AutoAgentChatTypeEnum.TOOL_ERROR.getValue(), payload);
            return;
        }
        if (customMetadata == null) {
            // No metadata means no event type to route under; the VO is still recorded for persistence.
            log.warn("auto agent event without custom_metadata not pushed, requestId:{}", reqId);
            return;
        }
        SSEUtils.pubMsg(reqId, resolveEventType(customMetadata), payload);
    }

    /** START for the runner's first event, END for its final one, otherwise the metadata's own type. */
    private static String resolveEventType(CustomMetadata customMetadata) {
        if (Boolean.TRUE.equals(customMetadata.getIs_runner_start())) {
            return AutoAgentChatTypeEnum.START.getValue();
        }
        if (Boolean.TRUE.equals(customMetadata.getIs_runner_final())) {
            return AutoAgentChatTypeEnum.END.getValue();
        }
        return customMetadata.getType();
    }

    /** True when at least one function response carries a non-empty error message. */
    private static boolean hasToolError(AutoAgentRagVO agentRagVO) {
        List<RagFunctionResponse> responses = agentRagVO.getFunction_response();
        if (CollectionUtils.isEmpty(responses)) {
            return false;
        }
        return responses.stream()
                .map(RagFunctionResponse::getResponse)
                .filter(Objects::nonNull)
                .anyMatch(response -> ObjectUtils.isNotEmpty(response.getError_message()));
    }

    // ------------------------------------------------------------------ persistence

    /** Creates the chat record on first use, then stores the user/AI items and per-event details. */
    private void asyncSaveChat(List<AutoAgentRagVO> ragVO) {
        String chatId = dto.getAgentData().getChatId();
        AutoAgentChatServiceImpl autoAgentChatService = (AutoAgentChatServiceImpl) SpringUtils.getBean("autoAgentChatServiceImpl");
        AutoAgentChat autoAgentChat = autoAgentChatService.findByChatId(chatId);
        if (ObjectUtils.isEmpty(autoAgentChat)) {
            String operator = String.valueOf(userInfo.getUserId());
            Date now = new Date();
            AutoAgentChat agentChat = AutoAgentChat.builder()
                    .agentId(dto.getAgentId())
                    .chatId(chatId)
                    .chatTitle(abbreviateTitle(dto.getAgentData().getQuestion()))
                    .chatSource(dto.getSource())
                    .createBy(operator)
                    .updateBy(operator)
                    .userId(operator)
                    .userName(userInfo.getUserName())
                    .createTime(now)
                    .updateTime(now)
                    .teamId(String.valueOf(userInfo.getOrgId()))
                    .tenantId(userInfo.getTenantCode())
                    .teamName(userInfo.getOrgCode())
                    .runStatus(runStatus)
                    .agentVersion(dto.getVersion())
                    .build();
            autoAgentChatService.save(agentChat);
        }
        saveItem(dto, userInfo, ragVO, autoAgentChat);
    }

    /**
     * Truncates the question to {@link #CHAT_TITLE_MAX_LENGTH} chars plus "..." without splitting a
     * surrogate pair; null is returned unchanged.
     */
    private static String abbreviateTitle(String question) {
        if (question == null || question.length() <= CHAT_TITLE_MAX_LENGTH) {
            return question;
        }
        int end = Character.isHighSurrogate(question.charAt(CHAT_TITLE_MAX_LENGTH - 1))
                ? CHAT_TITLE_MAX_LENGTH - 1 : CHAT_TITLE_MAX_LENGTH;
        return question.substring(0, end) + "...";
    }

    /**
     * Saves the Human item, an AI item for each runner-final VO, a FAIL item when the only VO is an
     * error, and one detail row per raw event; then clears the raw events.
     */
    private void saveItem(AutoAgentData dto, SessionUserInfo userInfo, List<AutoAgentRagVO> ragVO, AutoAgentChat autoAgentChat) {
        AutoAgentChatItemServiceImpl autoAgentChatItemService = (AutoAgentChatItemServiceImpl) SpringUtils.getBean("autoAgentChatItemServiceImpl");
        AutoAgentChatItemDetailServiceImpl autoAgentChatItemDetailService = (AutoAgentChatItemDetailServiceImpl) SpringUtils.getBean("autoAgentChatItemDetailServiceImpl");
        String chatId = dto.getAgentData().getChatId();
        String traceId = dto.getAgentData().getRequestId();
        String operator = String.valueOf(userInfo.getUserId());
        // Id of the Human item; distinct from the AI reply id (the chatItemId field) carried by each VO.
        String userItemId = UuidUtil.getUUID();
        Date now = new Date();

        AutoAgentChatItem itemUser = AutoAgentChatItem.builder()
                .agentId(dto.getAgentId())
                .chatId(chatId)
                .chatObj("Human")
                .chatItemId(userItemId)
                .traceId(traceId)
                .chatValue(dto.getAgentData().getQuestion())
                .createBy(operator)
                .updateBy(operator)
                .userId(operator)
                .userName(userInfo.getUserName())
                .createTime(now)
                .updateTime(now)
                .teamId(String.valueOf(userInfo.getOrgId()))
                .tenantId(userInfo.getTenantCode())
                .teamName(userInfo.getOrgCode())
                .globalVariables(dto.getVariables())
                .runningTime(totalTime)
                .build();
        if (CollectionUtils.isNotEmpty(dto.getFileList())) {
            // Stays on fastjson: setChatFileInfo is presumably typed as a fastjson JSONArray — TODO confirm before migrating.
            itemUser.setChatFileInfo(JSON.parseArray(JSON.toJSONString(dto.getFileList())));
        }

        List<AutoAgentChatItem> items = new ArrayList<>();
        items.add(itemUser);
        for (AutoAgentRagVO vo : ragVO) {
            CustomMetadata metadata = vo.getCustom_metadata();
            if (metadata != null && Boolean.TRUE.equals(metadata.getIs_runner_final())) {
                items.add(replyItem(itemUser, "AI", vo.getFinalResult(), vo.getChatItemId()));
            }
        }
        // A FAIL item is stored only when the stream produced exactly one VO and it carries an error message.
        if (ragVO.size() == 1 && ObjectUtil.isNotEmpty(ragVO.get(0).getError_message())) {
            AutoAgentRagVO failed = ragVO.get(0);
            items.add(replyItem(itemUser, "FAIL", failed.getError_message(), failed.getChatItemId()));
        }

        // NOTE(review): details are linked to the Human item id, not the AI reply id — confirm this is intended.
        List<AutoAgentChatItemDetail> itemDetails = new ArrayList<>(events.size());
        for (int i = 0; i < events.size(); i++) {
            AutoAgentChatItemDetail detail = createAutoAgentChatDetail(userInfo, traceId, userItemId);
            detail.setDetailData(events.get(i));
            // events and ragVO are appended together in onEvent; the bound check guards against drift.
            Long time = i < ragVO.size() ? ragVO.get(i).getTime() : null;
            detail.setExecuteTime(time == null ? 0L : time);
            itemDetails.add(detail);
        }

        autoAgentChatItemService.saveBatch(items);
        autoAgentChatItemDetailService.saveBatch(itemDetails);
        if (ObjectUtil.isNotEmpty(autoAgentChat)) {
            // NOTE(review): these changes are made on the in-memory entity only and never persisted
            // here; presumably an update call on the chat service is missing — confirm.
            autoAgentChat.setUpdateTime(now);
            autoAgentChat.setUpdateBy(operator);
        }
        events.clear();
    }

    /** Copies {@code template} (audit fields, files, variables, running time) into a reply item. */
    private static AutoAgentChatItem replyItem(AutoAgentChatItem template, String chatObj, String chatValue, String itemId) {
        AutoAgentChatItem item = BeanUtil.copyProperties(template, AutoAgentChatItem.class);
        item.setChatObj(chatObj);
        item.setChatValue(chatValue);
        item.setChatItemId(itemId);
        return item;
    }

    /** Builds a detail row with a fresh id and audit fields, linked to {@code chatItemId} and {@code traceId}. */
    @NotNull
    private static AutoAgentChatItemDetail createAutoAgentChatDetail(SessionUserInfo userInfo, String traceId, String chatItemId) {
        String operator = String.valueOf(userInfo.getUserId());
        Date now = new Date();
        AutoAgentChatItemDetail detail = new AutoAgentChatItemDetail();
        detail.setDetailId(IdUtil.fastSimpleUUID());
        detail.setCreateBy(operator);
        detail.setCreateTime(now);
        detail.setUpdateTime(now);
        detail.setUpdateBy(operator);
        detail.setItemId(chatItemId);
        detail.setTraceId(traceId);
        return detail;
    }

    // ------------------------------------------------------------------ Gson parsing

    /** Copies the top-level {@code error_code} / {@code error_message} into the VO (null when absent). */
    private static void readErrorFields(JsonObject jsonObject, AutoAgentRagVO agentRagVO) {
        agentRagVO.setError_code(getString(jsonObject, AutoAgentChatTypeEnum.ERROR_CODE.getValue()));
        agentRagVO.setError_message(getString(jsonObject, AutoAgentChatTypeEnum.ERROR_MESSAGE.getValue()));
    }

    /**
     * Extracts function calls, function responses and text from {@code content.parts}.
     * Text on a part whose {@code thought} is true becomes the think result; text on a part with no
     * {@code thought} becomes the final result; text with {@code thought == false} is ignored.
     * When {@code content} is absent or not an object the VO is left untouched.
     */
    private static void writeRagVo(JsonObject jsonObject, AutoAgentRagVO agentRagVO) {
        JsonObject content = getJsonObject(jsonObject, AutoAgentChatTypeEnum.CONTENT.getValue());
        if (content == null) {
            return;
        }
        List<RagFunctionCall> functionCalls = new ArrayList<>();
        List<RagFunctionResponse> functionResponses = new ArrayList<>();
        JsonArray parts = getJsonArray(content, AutoAgentChatTypeEnum.PARTS.getValue());
        if (parts != null) {
            for (JsonElement element : parts) {
                if (!element.isJsonObject()) {
                    continue;
                }
                JsonObject part = element.getAsJsonObject();
                JsonObject call = getJsonObject(part, AutoAgentChatTypeEnum.FUNCTION_CALL.getValue());
                if (call != null) {
                    functionCalls.add(getRagFunctionCall(call));
                }
                JsonObject response = getJsonObject(part, AutoAgentChatTypeEnum.FUNCTION_RESPONSE.getValue());
                if (response != null) {
                    functionResponses.add(getRagFunctionResponse(response));
                }
                String text = getString(part, AutoAgentChatTypeEnum.TEXT.getValue());
                if (text != null) {
                    JsonElement thought = getMember(part, AutoAgentChatTypeEnum.THOUGHT.getValue());
                    if (thought == null) {
                        agentRagVO.setFinalResult(text);
                    } else if (thought.isJsonPrimitive() && thought.getAsBoolean()) {
                        agentRagVO.setThinkResult(text);
                    }
                }
            }
        }
        agentRagVO.setFunction_call(functionCalls);
        agentRagVO.setFunction_response(functionResponses);
    }

    /** Maps a {@code function_call} object; {@code args} is kept as compact JSON text ("" when absent). */
    @NotNull
    private static RagFunctionCall getRagFunctionCall(JsonObject call) {
        RagFunctionCall functionCall = new RagFunctionCall();
        JsonElement args = getMember(call, "args");
        functionCall.setArgs(args != null ? args.toString() : "");
        functionCall.setName(getString(call, "name"));
        functionCall.setId(getString(call, "id"));
        return functionCall;
    }

    /** Maps a {@code function_response} object, including its nested {@code response} (may be absent). */
    @NotNull
    private static RagFunctionResponse getRagFunctionResponse(JsonObject functionResponse) {
        JsonObject response = getJsonObject(functionResponse, AutoAgentChatTypeEnum.RESPONSE.getValue());
        RagResponse ragResponse = new RagResponse();
        ragResponse.setError_message(getString(response, AutoAgentChatTypeEnum.ERROR_MESSAGE.getValue()));
        ragResponse.setResult(getString(response, "result"));
        ragResponse.setTool_type(getString(response, "tool_type"));

        RagFunctionResponse ragFunctionResponse = new RagFunctionResponse();
        ragFunctionResponse.setId(getString(functionResponse, "id"));
        ragFunctionResponse.setName(getString(functionResponse, "name"));
        ragFunctionResponse.setScheduling(getString(functionResponse, "scheduling"));
        ragFunctionResponse.setWill_continue(getString(functionResponse, "will_continue"));
        ragFunctionResponse.setResponse(ragResponse);
        return ragFunctionResponse;
    }

    /** Returns the member value, or null if {@code object} is null, the key is absent, or the value is JSON null. */
    private static JsonElement getMember(JsonObject object, String key) {
        if (object == null) {
            return null;
        }
        JsonElement value = object.get(key);
        return value == null || value.isJsonNull() ? null : value;
    }

    /** Like {@link #getMember} but only returns JSON objects. */
    private static JsonObject getJsonObject(JsonObject object, String key) {
        JsonElement value = getMember(object, key);
        return value != null && value.isJsonObject() ? value.getAsJsonObject() : null;
    }

    /** Like {@link #getMember} but only returns JSON arrays. */
    private static JsonArray getJsonArray(JsonObject object, String key) {
        JsonElement value = getMember(object, key);
        return value != null && value.isJsonArray() ? value.getAsJsonArray() : null;
    }

    /** Primitive values as their string form, objects/arrays as compact JSON text, otherwise null. */
    private static String getString(JsonObject object, String key) {
        JsonElement value = getMember(object, key);
        if (value == null) {
            return null;
        }
        return value.isJsonPrimitive() ? value.getAsString() : value.toString();
    }
}
这是我的业务代码，请帮我优化：减少重复代码，并将 JSON 解析工具统一为 Gson。
最新发布