一、前言
通过参与“开源模型应用落地-业务整合系列篇”的学习,我们已经成功建立了基本的业务流程。然而,这只是迈出了万里长征的第一步。现在我们要对整个项目进行优化,以提高效率。我们计划利用线程池来加快处理速度,使用redis来实现排队需求,以及通过多级环境来减轻负载压力。这些优化措施将有助于我们进一步改进项目的性能和效果。
二、术语
2.1. 线程池
是一种用于线程管理的技术,它包含一组预先创建的线程,用于执行任务。线程池维护着一个任务队列,当有任务到达时,线程池中的线程会自动分配任务并执行。
线程池的主要目的是重用线程,避免频繁地创建和销毁线程带来的开销。通过使用线程池,可以在程序初始化时创建一组线程,并将任务提交给线程池进行处理,而不需要为每个任务都创建一个新的线程。这样可以有效地管理系统中的线程数量,控制并发度,提高系统的性能和资源利用率。
线程池通常包含以下几个关键组件:
- 任务队列(Task Queue):用于存储待执行的任务,通常是一个队列结构。当有新的任务到达时,会被添加到任务队列中。
- 线程池管理器(Thread Pool Manager):负责管理线程池的创建、销毁和线程的调度。它会监视任务队列的状态,并根据需要动态地创建或回收线程。
- 工作线程(Worker Threads):线程池中的线程,用于执行任务。它们会从任务队列中获取任务,并执行任务的处理逻辑。
三、前置条件
3.1. 已搭建WebSocket与AI服务调用链路
四、技术实现
4.1. 调整业务逻辑处理类
对于每次交互的chat对话,都需要经过以下步骤,包括但不限于:
- 对用户输入的内容进行自定义违规词检测
- 对用户输入的内容进行第三方在线违规词检测
- 对用户输入的内容进行组装成Prompt
- 对Prompt根据业务进行增强(完善prompt的内容)
- 对history进行裁剪或总结(检测history是否操作模型支持的上下文长度,例如qwen-7b支持的上下文长度为8192)
特别是调用第三方在线违规词检测,例如:某某云的内容安全审核服务,是非常耗时,会阻塞正常线程的执行,导致吞吐量的下降。
所以,我们就要对下面这块的处理逻辑进行调整,通过自定义线程池的方式,去处理核心的Chat交互流程
调整后:
4.2. 新增线程处理类
-
import io.netty.channel.ChannelHandlerContext;
-
import lombok.extern.slf4j.Slf4j;
-
import org.springframework.beans.factory.annotation.Autowired;
-
import org.springframework.stereotype.Component;
-
-
import java.util.List;
-
import java.util.concurrent.ExecutorService;
-
import java.util.concurrent.Executors;
-
-
@Component
-
@Slf4j
-
public
class
TaskUtils{
-
private
static
ExecutorService
executorService
= Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() *
2);
-
@Autowired
-
private AIChatUtils aiChatUtils;
-
-
public
void
execute
(AITaskReqMessage aiTaskReqMessage) {
-
-
executorService.execute(() -> {
-
Long
userId
= aiTaskReqMessage.getUserId();
-
-
if (
null == userId || (
long) userId <
10000) {
-
log.warn(
"用户身份标识有误!");
-
return;
-
}
-
-
ChannelHandlerContext
channelHandlerContext
= AbstractBusinessLogicHandler.getContextByUserId(userId);
-
-
if (channelHandlerContext !=
null) {
-
try {
-
aiChatUtils.chatStream(aiTaskReqMessage);
-
-
}
catch (Throwable exception) {
-
exception.printStackTrace();
-
}
-
}
-
});
-
}
-
-
public
static
void
destory
(){
-
executorService.shutdownNow();
-
executorService =
null;
-
}
-
-
}
4.3. 新增线程处理实体类
-
import lombok.Builder;
-
import lombok.Getter;
-
import lombok.Setter;
-
-
import java.util.List;
-
-
@Builder
-
@Setter
-
@Getter
-
public
class
AITaskReqMessage {
-
-
private String messageId;
-
private Long userId;
-
private String contents;
-
private List<ChatContext> history;
-
}
五、测试
在线测试方式:WebSocket在线测试工具
5.1. 建立连接
5.2. 业务初始化
服务端输出:
5.3. 业务对话
服务端输出
5.4. 关闭连接
六、附带说明
6.1. 可以使用jmeter进行websocket压测,以评估各项性能指标是否符合预期(下一篇)
6.2. BusinessHandler完整代码
-
import com.alibaba.fastjson.JSON;
-
import io.netty.channel.ChannelHandler;
-
import lombok.extern.slf4j.Slf4j;
-
import io.netty.channel.ChannelHandlerContext;
-
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
-
import org.apache.commons.lang3.StringUtils;
-
import org.springframework.beans.factory.annotation.Autowired;
-
import org.springframework.stereotype.Component;
-
-
import java.util.List;
-
-
-
/**
-
* @Description: 处理消息的handler
-
*/
-
@Slf4j
-
@ChannelHandler.Sharable
-
@Component
-
public
class
BusinessHandler
extends
AbstractBusinessLogicHandler<TextWebSocketFrame> {
-
@Autowired
-
private TaskUtils taskExecuteUtils;
-
-
@Override
-
public
void
handlerAdded
(ChannelHandlerContext ctx)
throws Exception {
-
String
channelId
= ctx.channel().id().asShortText();
-
log.info(
"add client,channelId:{}", channelId);
-
}
-
-
@Override
-
public
void
handlerRemoved
(ChannelHandlerContext ctx)
throws Exception {
-
String
channelId
= ctx.channel().id().asShortText();
-
log.info(
"remove client,channelId:{}", channelId);
-
}
-
-
-
@Override
-
protected
void
channelRead0
(ChannelHandlerContext channelHandlerContext, TextWebSocketFrame textWebSocketFrame)
-
throws Exception {
-
// 获取客户端传输过来的消息
-
String
content
= textWebSocketFrame.text();
-
// 兼容在线测试
-
if (StringUtils.equals(content,
"PING")) {
-
buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
-
.respTime(String.valueOf(System.currentTimeMillis()))
-
.msgType(String.valueOf(MsgType.HEARTBEAT.getCode()))
-
.contents(
"心跳测试,很高兴收到你的心跳包")
-
.build());
-
-
return;
-
}
-
log.info(
"接收到客户端发送的信息: {}", content);
-
-
Long userIdForReq;
-
String
msgType
=
"";
-
String
contents
=
"";
-
-
try {
-
ApiReqMessage
apiReqMessage
= JSON.parseObject(content, ApiReqMessage.class);
-
msgType = apiReqMessage.getMsgType();
-
contents = apiReqMessage.getContents();
-
-
-
userIdForReq = apiReqMessage.getUserId();
-
// 用户身份标识校验
-
if (
null == userIdForReq || (
long) userIdForReq <=
10000) {
-
ApiRespMessage
apiRespMessage
= ApiRespMessage.builder().code(String.valueOf(StatusCode.SYSTEM_ERROR.getCode()))
-
.respTime(String.valueOf(System.currentTimeMillis()))
-
.contents(
"用户身份标识有误!")
-
.msgType(String.valueOf(MsgType.SYSTEM.getCode()))
-
.build();
-
buildResponseAndClose(channelHandlerContext, apiRespMessage);
-
return;
-
}
-
-
-
if (StringUtils.equals(msgType, String.valueOf(MsgType.CHAT.getCode()))) {
-
// 对用户输入的内容进行自定义违规词检测
-
// 对用户输入的内容进行第三方在线违规词检测
-
// 对用户输入的内容进行组装成Prompt
-
// 对Prompt根据业务进行增强(完善prompt的内容)
-
// 对history进行裁剪或总结(检测history是否操作模型支持的上下文长度,例如qwen-7b支持的上下文长度为8192)
-
// ...
-
String
messageId
= apiReqMessage.getMessageId();
-
List<ChatContext> history = apiReqMessage.getHistory();
-
AITaskReqMessage
aiTaskReqMessage
= AITaskReqMessage.builder().messageId(messageId).userId(userIdForReq).contents(contents).history(history).build();
-
taskExecuteUtils.execute(aiTaskReqMessage);
-
-
-
}
else
if (StringUtils.equals(msgType, String.valueOf(MsgType.INIT.getCode()))) {
-
//一、业务黑名单检测(多次违规,永久锁定)
-
-
//二、账户锁定检测(临时锁定)
-
-
//三、多设备登录检测
-
-
//四、剩余对话次数检测
-
-
//检测通过,绑定用户与channel之间关系
-
addChannel(channelHandlerContext, userIdForReq);
-
String
respMessage
=
"用户标识: " + userIdForReq +
" 登录成功";
-
-
buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
-
.respTime(String.valueOf(System.currentTimeMillis()))
-
.msgType(String.valueOf(MsgType.INIT.getCode()))
-
.contents(respMessage)
-
.build());
-
-
}
else
if (StringUtils.equals(msgType, String.valueOf(MsgType.HEARTBEAT.getCode()))) {
-
-
buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
-
.respTime(String.valueOf(System.currentTimeMillis()))
-
.msgType(String.valueOf(MsgType.HEARTBEAT.getCode()))
-
.contents(
"心跳测试,很高兴收到你的心跳包")
-
.build());
-
}
-
else {
-
log.info(
"用户标识: {}, 消息类型有误,不支持类型: {}", userIdForReq, msgType);
-
}
-
-
-
}
catch (Exception e) {
-
log.warn(
"【BusinessHandler】接收到请求内容:{},异常信息:{}", content, e.getMessage(), e);
-
// 异常返回
-
return;
-
}
-
-
}
-
-
}
6.3. AIChatUtils完整代码
-
import com.alibaba.fastjson.JSON;
-
import lombok.extern.slf4j.Slf4j;
-
import okhttp3.MediaType;
-
import okhttp3.Request;
-
import okhttp3.RequestBody;
-
import okhttp3.Response;
-
import org.apache.commons.lang3.StringUtils;
-
import org.springframework.beans.factory.annotation.Autowired;
-
import org.springframework.stereotype.Component;
-
-
import java.io.ByteArrayOutputStream;
-
import java.io.InputStream;
-
import java.nio.charset.StandardCharsets;
-
import java.security.MessageDigest;
-
import java.util.List;
-
import java.util.Objects;
-
-
@Slf4j
-
@Component
-
public
class
AIChatUtils {
-
@Autowired
-
private AIConfig aiConfig;
-
-
private Request
buildRequest
(Long userId, String prompt)
throws Exception {
-
//创建一个请求体对象(body)
-
MediaType
mediaType
= MediaType.parse(
"application/json");
-
RequestBody
requestBody
= RequestBody.create(mediaType, prompt);
-
-
return buildHeader(userId,
new
Request.Builder().post(requestBody))
-
.url(aiConfig.getUrl()).build();
-
}
-
-
private Request.Builder
buildHeader
(Long userId, Request.Builder builder)
throws Exception {
-
return builder
-
.addHeader(
"Content-Type",
"application/json")
-
.addHeader(
"userId", String.valueOf(userId))
-
.addHeader(
"secret",generateSecret(userId))
-
}
-
-
-
-
/**
-
* 生成请求密钥
-
*
-
* @param userId 用户ID
-
* @return
-
*/
-
private String
generateSecret
(Long userId)
throws Exception {
-
String
key
= aiConfig.getServerKey();
-
String
content
= key + userId + key;
-
-
MessageDigest
digest
= MessageDigest.getInstance(
"SHA-256");
-
byte[] hash = digest.digest(content.getBytes(StandardCharsets.UTF_8));
-
-
StringBuilder
hexString
=
new
StringBuilder();
-
for (
byte b : hash) {
-
String
hex
= Integer.toHexString(
0xff & b);
-
if (hex.length() ==
1) {
-
hexString.append(
'0');
-
}
-
hexString.append(hex);
-
}
-
return hexString.toString();
-
}
-
-
public String
chatStream
(AITaskReqMessage aiTaskReqMessage)
throws Exception {
-
-
String
messageId
= aiTaskReqMessage.getMessageId();
-
Long
userId
= aiTaskReqMessage.getUserId();
-
String
contents
= aiTaskReqMessage.getContents();
-
List<ChatContext> history = aiTaskReqMessage.getHistory();
-
-
if(StringUtils.isEmpty(contents) || StringUtils.isBlank(contents)){
-
log.warn(
"用户输入内容不能为空!");
-
return
null;
-
}
-
-
//定义请求的参数
-
String
prompt
= JSON.toJSONString(AIChatReqVO.init(contents, history));
-
log.info(
"【AIChatUtils】调用AI聊天,用户({}),prompt:{}", userId, prompt);
-
-
//创建一个请求对象
-
Request
request
= buildRequest(userId, prompt);
-
-
InputStream
is
=
null;
-
try {
-
-
// 从线程池获取http请求并执行
-
Response
response
=OkHttpUtils.getInstance(aiConfig).getOkHttpClient().newCall(request).execute();
-
-
// 响应结果
-
StringBuffer
resultBuff
=
new
StringBuffer();
-
//正常返回
-
if (response.code() ==
200) {
-
//打印返回的字符数据
-
is = response.body().byteStream();
-
byte[] bytes =
new
byte[
1024];
-
-
int
len
= is.read(bytes);
-
while (len != -
1) {
-
ByteArrayOutputStream
outputStream
=
new
ByteArrayOutputStream();
-
outputStream.write(bytes,
0, len);
-
outputStream.flush();
-
// 本轮读取到的数据
-
String
result
=
new
String(outputStream.toByteArray(), StandardCharsets.UTF_8);
-
resultBuff.append(result);
-
-
len = is.read(bytes);
-
-
// 将数据逐个传输给用户
-
AbstractBusinessLogicHandler.pushChatMessageForUser(userId, result);
-
}
-
-
// 正常响应
-
return resultBuff.toString();
-
}
-
else {
-
String
result
= response.body().string();
-
log.warn(
"处理异常,异常描述:{}",result);
-
}
-
}
catch (Throwable e) {
-
log.error(
"【AIChatUtils】消息({})调用AI聊天 chatStream 异常,异常消息:{}", messageId, e.getMessage(), e);
-
-
}
finally {
-
if (!Objects.isNull(is)) {
-
try {
-
is.close();
-
}
catch (Exception e) {
-
e.printStackTrace();
-
}
-
}
-
}
-
return
null;
-
}
-
-
-
}