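I've recently been learning Spring AI. The InMemoryChatMemory that Spring AI provides keeps chat history in memory, so the whole conversation with the AI is lost as soon as the application stops or the computer shuts down. Storing the chat history in a MySQL database instead keeps it.
The code below adapts the InMemoryChatMemory approach so that chat history is persisted to MySQL.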
1. Database table design
CREATE TABLE chat_messages (
    id BIGINT AUTO_INCREMENT PRIMARY KEY,
    conversation_id VARCHAR(255) NOT NULL,
    message_type ENUM('USER', 'ASSISTANT') NOT NULL,
    content TEXT NOT NULL,
    created_at DATETIME NOT NULL,
    INDEX idx_conversation_id (conversation_id)
);
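id: auto-increment primary key
conversation_id: conversation ID, used to tell different conversations apart
message_type: message type, recording whether the message came from the user or from the AI model
content: message content
created_at: time the record was created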
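One sizing detail worth checking: MySQL's TEXT type holds at most 65,535 bytes, not characters, and long model replies can approach that. The minimal sketch below (TextColumnCheck and fitsInText are hypothetical names) assumes the table uses the utf8mb4 character set, i.e. UTF-8, and checks whether a message fits. Because a CJK character takes 3 bytes in UTF-8, the limit works out to 21,845 such characters. If replies may be longer, MEDIUMTEXT (up to 16,777,215 bytes) is the next size up.

```java
import java.nio.charset.StandardCharsets;

public class TextColumnCheck {
    // MySQL TEXT stores at most 65,535 (2^16 - 1) bytes, regardless of character count
    static final int TEXT_MAX_BYTES = 65_535;

    // Assumption: the column uses utf8mb4, so its byte length equals the UTF-8 length
    static boolean fitsInText(String content) {
        return content.getBytes(StandardCharsets.UTF_8).length <= TEXT_MAX_BYTES;
    }

    public static void main(String[] args) {
        // U+4E2D is a CJK character that takes 3 bytes in UTF-8
        String fits = "\u4E2D".repeat(21_845);    // 65,535 bytes
        String tooLong = "\u4E2D".repeat(21_846); // 65,538 bytes
        System.out.println(fitsInText(fits));     // true
        System.out.println(fitsInText(tooLong));  // false
    }
}
```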
2. Create the entity class for the table
import lombok.Data;
import java.time.LocalDateTime;

@Data
public class ChatMessageEntity {
    private Long id;
    private String conversationId;
    private MessageType messageType;
    private String content;
    private LocalDateTime createdAt;

    public enum MessageType {
        USER, ASSISTANT
    }
}
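The Java enum and the SQL ENUM have to stay in sync. MyBatis's default EnumTypeHandler writes an enum constant using name() and reads it back using Enum.valueOf, so the column values must be spelled exactly like the Java constants. The plain-Java sketch below reproduces those two conversions without MyBatis (EnumColumnMapping, toColumn and fromColumn are hypothetical names) and shows that a value the Java enum lacks, for example a SYSTEM member added only to the SQL ENUM, cannot be read back.

```java
public class EnumColumnMapping {
    // Mirrors ChatMessageEntity.MessageType
    enum MessageType { USER, ASSISTANT }

    // Writing: the default handler stores the constant's name()
    static String toColumn(MessageType type) {
        return type.name();
    }

    // Reading: the default handler resolves the column value with Enum.valueOf
    static MessageType fromColumn(String value) {
        return MessageType.valueOf(value);
    }

    public static void main(String[] args) {
        System.out.println(toColumn(MessageType.ASSISTANT)); // ASSISTANT
        System.out.println(fromColumn("USER"));              // USER
        try {
            // Hypothetical value present in the SQL ENUM but not in the Java enum
            fromColumn("SYSTEM");
        } catch (IllegalArgumentException e) {
            System.out.println("SYSTEM has no matching Java constant");
        }
    }
}
```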
3. Create the Mapper interface
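import com.zry.ai.entity.ChatMemoryEntity.ChatMessageEntity;
import org.apache.ibatis.annotations.*;
import java.util.List;

@Mapper
public interface MessageMapper {

    // Batch insert: <foreach> emits one "(...)" row per message, separated by commas
    @Insert({
            "<script>",
            "INSERT INTO chat_messages (conversation_id, message_type, content, created_at)",
            "VALUES ",
            "<foreach collection='messages' item='msg' separator=','>",
            "(#{msg.conversationId}, #{msg.messageType}, #{msg.content}, #{msg.createdAt})",
            "</foreach>",
            "</script>"
    })
    void insertMessages(@Param("messages") List<ChatMessageEntity> messages);

    // Newest lastN messages first. The aliases map the snake_case columns onto the
    // entity's camelCase fields even without map-underscore-to-camel-case, and id
    // breaks ties between rows written within the same second (DATETIME has 1 s resolution)
    @Select("SELECT id, conversation_id AS conversationId, message_type AS messageType, " +
            "content, created_at AS createdAt " +
            "FROM chat_messages " +
            "WHERE conversation_id = #{conversationId} " +
            "ORDER BY created_at DESC, id DESC " +
            "LIMIT #{lastN}")
    List<ChatMessageEntity> findLastNMessages(@Param("conversationId") String conversationId,
                                              @Param("lastN") int lastN);

    @Delete("DELETE FROM chat_messages WHERE conversation_id = #{conversationId}")
    void deleteByConversationId(@Param("conversationId") String conversationId);
}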
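To make the dynamic SQL in insertMessages concrete: &lt;foreach&gt; repeats the row template once per list element, joined by the separator, and MyBatis turns every #{...} into a JDBC ? placeholder. The sketch below (ForeachExpansion and expandInsert are hypothetical, not MyBatis code) builds the equivalent statement text, up to whitespace, for n messages. It also shows why an empty list has to be avoided: with no rows the statement would end in VALUES, which is invalid SQL.

```java
import java.util.Collections;

public class ForeachExpansion {
    // One row group: each #{msg.xxx} in the mapper becomes a JDBC "?" placeholder
    static final String ROW = "(?, ?, ?, ?)";

    // Builds, up to whitespace, the statement insertMessages sends for n messages
    static String expandInsert(int n) {
        if (n <= 0) {
            // No rows would leave "... VALUES" with nothing after it: invalid SQL
            throw new IllegalArgumentException("at least one message is required");
        }
        return "INSERT INTO chat_messages (conversation_id, message_type, content, created_at) VALUES "
                + String.join(",", Collections.nCopies(n, ROW));
    }

    public static void main(String[] args) {
        // Two messages: two row groups and 8 bound parameters,
        // ending in "VALUES (?, ?, ?, ?),(?, ?, ?, ?)"
        System.out.println(expandInsert(2));
    }
}
```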
4. Implement Spring AI's ChatMemory interface
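The ChatMemory interface, in the Spring AI 1.0.0 milestone API that this code is written against, declares add(String conversationId, List<Message> messages), get(String conversationId, int lastN) and clear(String conversationId).
Implementing ChatMemory on top of MessageMapper: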
import com.zry.ai.entity.ChatMemoryEntity.ChatMessageEntity;
import com.zry.ai.mapper.MessageMapper;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.stereotype.Component;
import java.time.LocalDateTime;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

@Component
@RequiredArgsConstructor
public class MysqlChatMemory implements ChatMemory {

    private final MessageMapper messageMapper;

    @Override
    public void add(String conversationId, List<Message> messages) {
        List<ChatMessageEntity> entities = messages.stream()
                // Only user and assistant messages fit the table's ENUM; skip anything else
                // (e.g. system messages) instead of inserting a NULL message_type
                .filter(msg -> msg instanceof UserMessage || msg instanceof AssistantMessage)
                .map(msg -> {
                    ChatMessageEntity entity = new ChatMessageEntity();
                    entity.setConversationId(conversationId);
                    entity.setContent(msg.getText());
                    entity.setMessageType(msg instanceof UserMessage
                            ? ChatMessageEntity.MessageType.USER
                            : ChatMessageEntity.MessageType.ASSISTANT);
                    entity.setCreatedAt(LocalDateTime.now());
                    return entity;
                })
                .collect(Collectors.toList());
        // An empty <foreach> would produce "INSERT ... VALUES" with no rows (invalid SQL)
        if (!entities.isEmpty()) {
            messageMapper.insertMessages(entities);
        }
    }

    @Override
    public List<Message> get(String conversationId, int lastN) {
        // The query returns the newest lastN rows first; reverse them into chronological order
        List<ChatMessageEntity> entities = messageMapper.findLastNMessages(conversationId, lastN);
        Collections.reverse(entities);
        return entities.stream()
                .map(entity -> {
                    switch (entity.getMessageType()) {
                        case USER:
                            return new UserMessage(entity.getContent());
                        case ASSISTANT:
                            return new AssistantMessage(entity.getContent());
                        default:
                            throw new IllegalArgumentException("Unknown message type: " + entity.getMessageType());
                    }
                })
                .collect(Collectors.toList());
    }

    @Override
    public void clear(String conversationId) {
        messageMapper.deleteByConversationId(conversationId);
    }
}
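The get() method relies on a two-step trick: the query fetches the newest lastN rows (newest first), and Collections.reverse puts them back into chronological order before they reach the model. The self-contained sketch below replays that logic over an in-memory list; the Row record and the sample rows are hypothetical. It orders by created_at and then by id, because DATETIME only has one-second resolution and a user message and its reply can share a timestamp.

```java
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

public class LastNSketch {
    // Hypothetical stand-in for one chat_messages row
    record Row(long id, String conversationId, String type, String content, LocalDateTime createdAt) {}

    // Newest first: created_at DESC, then id DESC for rows written in the same second
    static final Comparator<Row> NEWEST_FIRST =
            Comparator.comparing(Row::createdAt).thenComparingLong(Row::id).reversed();

    // Same idea as findLastNMessages followed by Collections.reverse in get()
    static List<Row> lastN(List<Row> table, String conversationId, int lastN) {
        List<Row> newest = table.stream()
                .filter(r -> r.conversationId().equals(conversationId))
                .sorted(NEWEST_FIRST)
                .limit(lastN)
                .collect(Collectors.toCollection(ArrayList::new));
        Collections.reverse(newest); // back to chronological order (oldest first)
        return newest;
    }

    public static void main(String[] args) {
        LocalDateTime t = LocalDateTime.of(2025, 1, 1, 12, 0, 0);
        List<Row> table = List.of(
                new Row(1, "c1", "USER", "hi", t),
                new Row(2, "c1", "ASSISTANT", "hello", t), // same second as row 1
                new Row(3, "c1", "USER", "how are you?", t.plusSeconds(5)),
                new Row(4, "c2", "USER", "another conversation", t.plusSeconds(6)));
        // Prints "2 ASSISTANT" then "3 USER"
        lastN(table, "c1", 2).forEach(r -> System.out.println(r.id() + " " + r.type()));
    }
}
```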
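When configuring the ChatClient, add MysqlChatMemory as around advice, i.e. register it through a chat-memory advisor such as Spring AI's MessageChatMemoryAdvisor, so that stored history is loaded before each model call and the new messages are saved after it.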
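Conceptually, "around advice" means the memory wraps each model call: history is read before the call and the new turn is written after it. Spring AI's memory advisors do this for real; the plain-Java sketch below only illustrates the idea, and every name in it (Msg, Memory, callWithMemory, the fake model) is hypothetical. Details such as exactly when the user message gets saved may differ in Spring AI's own implementation.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class MemoryAdvisorSketch {
    // Hypothetical message; Spring AI uses UserMessage / AssistantMessage instead
    record Msg(String role, String text) {}

    // Hypothetical in-memory stand-in for MysqlChatMemory's add/get
    static class Memory {
        private final Map<String, List<Msg>> store = new HashMap<>();

        void add(String conversationId, List<Msg> messages) {
            store.computeIfAbsent(conversationId, k -> new ArrayList<>()).addAll(messages);
        }

        // Last N messages in chronological order, returned as a mutable copy
        List<Msg> get(String conversationId, int lastN) {
            List<Msg> all = store.getOrDefault(conversationId, List.of());
            return new ArrayList<>(all.subList(Math.max(0, all.size() - lastN), all.size()));
        }
    }

    // "Around" the model call: load history before it, persist the new turn after it
    static String callWithMemory(Memory memory, String conversationId, String userText,
                                 int lastN, Function<List<Msg>, String> model) {
        List<Msg> prompt = memory.get(conversationId, lastN); // before: history first
        Msg user = new Msg("USER", userText);
        prompt.add(user);
        String reply = model.apply(prompt);                    // the actual model call
        memory.add(conversationId, List.of(user, new Msg("ASSISTANT", reply))); // after
        return reply;
    }

    public static void main(String[] args) {
        Memory memory = new Memory();
        // Fake model that only reports how many messages it was given
        Function<List<Msg>, String> model = msgs -> "got " + msgs.size();
        System.out.println(callWithMemory(memory, "c1", "hi", 10, model));    // got 1
        System.out.println(callWithMemory(memory, "c1", "again", 10, model)); // got 3
    }
}
```

The second call sees three messages because the first user message and the fake reply were stored after the first call.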