RAG系列文章目录
第一章 Springboot RAG 一站式混合搜索方案
文章目录
前言
最近在做一个政策类查询的RAG方案,做成一站式可以快速使用的方案。
数据库是PG, PGVector作为向量数据库,采用Hybrid Search方法来同时匹配向量和其他字段。
项目采用Springboot 作为后端;大模型相关的,使用到的API有Chatgpt, moonshot, qwen,讯飞星火等不同厂家的方案。
该方案从产品方面来考虑,可扩展性,可便利性等没有太多考虑;从单个项目来说,算是一个可用的方案。
Spring AI 支持所有主要的模型提供商,如 OpenAI、Microsoft、Amazon、Google 和 Huggingface;国内的大模型还没有支持,国内大模型的API的返回,有几个是兼容OpenAI的,另外一些是不兼容的,需要做不少工作来完全兼容。这是后面可以优化的方向,做成一个统一的接口,便于系统维护和更多人的上手使用。
一、PG
系统的数据库采用PG, 文本也是放在一个text字段中,用PG自带的全文检索,同时把向量匹配也放到一起过滤,所以向量数据库采用PGVector。
一、PGVector
PGVector的安装有很多写的详细的过程,这里略过。
二、PG全文搜索
PostgreSQL 全文检索:PostgreSQL 自带的全文检索功能可以使用 tsvector 和 tsquery 数据类型,通过分词和倒排索引来实现语义搜索的基本需求。
Zhparser 分词插件:对于中文文本,可以使用 PostgreSQL 的 Zhparser 插件进行中文分词,结合全文检索功能实现语义搜索。
二、SpringBoot 结合PGVector
PGVector.java
系统里面使用SpringBoot, MyBatisPlus的方式来,目前MyBatisPlus本身并不支持PGVector。需要增加一个PGVector的方式,
PGVector有项目https://github.com/pgvector/pgvector-java.git, 实现过程中参考该项目来实现PGVector;(也可以参考SpringAI项目对PGVector支持)
package com.md.gpt.vector;
import org.postgresql.PGConnection;
import org.postgresql.util.ByteConverter;
import org.postgresql.util.PGBinaryObject;
import org.postgresql.util.PGobject;
import java.io.Serializable;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
public class PGvector extends PGobject implements PGBinaryObject, Serializable, Cloneable {
private float[] vec;
/**
* Constructor
*/
public PGvector() {
type = "vector";
}
/**
* Constructor
*
* @param v float array
*/
public PGvector(float[] v) {
this();
vec = v;
}
/**
* Constructor
*
* @param <T> number
* @param v list of numbers
*/
public <T extends Number> PGvector(List<T> v) {
this();
if (Objects.isNull(v)) {
vec = null;
} else {
vec = new float[v.size()];
int i = 0;
for (T f : v) {
vec[i++] = f.floatValue();
}
}
}
/**
* Constructor
*
* @param s text representation of a vector
* @throws SQLException exception
*/
public PGvector(String s) throws SQLException {
this();
setValue(s);
}
/**
* Sets the value from a text representation of a vector
*/
@Override
public void setValue(String s) throws SQLException {
if (s == null) {
vec = null;
} else {
String[] sp = s.substring(1, s.length() - 1).split(",");
vec = new float[sp.length];
for (int i = 0; i < sp.length; i++) {
vec[i] = Float.parseFloat(sp[i]);
}
}
}
/**
* Returns the text representation of a vector
*/
@Override
public String getValue() {
if (vec == null) {
return null;
} else {
return Arrays.toString(vec).replace(" ", "");
}
}
/**
* Returns the number of bytes for the binary representation
*/
@Override
public int lengthInBytes() {
return vec == null ? 0 : 4 + vec.length * 4;
}
/**
* Sets the value from a binary representation of a vector
*/
@Override
public void setByteValue(byte[] value, int offset) throws SQLException {
int dim = ByteConverter.int2(value, offset);
int unused = ByteConverter.int2(value, offset + 2);
if (unused != 0) {
throw new SQLException("expected unused to be 0");
}
vec = new float[dim];
for (int i = 0; i < dim; i++) {
vec[i] = ByteConverter.float4(value, offset + 4 + i * 4);
}
}
/**
* Writes the binary representation of a vector
*/
@Override
public void toBytes(byte[] bytes, int offset) {
if (vec == null) {
return;
}
// server will error on overflow due to unconsumed buffer
// could set to Short.MAX_VALUE for friendlier error message
ByteConverter.int2(bytes, offset, vec.length);
ByteConverter.int2(bytes, offset + 2, 0);
for (int i = 0; i < vec.length; i++) {
ByteConverter.float4(bytes, offset + 4 + i * 4, vec[i]);
}
}
/**
* Returns an array
*
* @return an array
*/
public float[] toArray() {
return vec;
}
/**
* Registers the vector type
*
* @param conn connection
* @throws SQLException exception
*/
public static void addVectorType(Connection conn) throws SQLException {
conn.unwrap(PGConnection.class).addDataType("vector", PGvector.class);
}
}
PGvectorTypeHandler.java
PGvectorTypeHandler 类主要如下
import com.md.gpt.util.FloatUtil;
import com.md.gpt.vector.PGvector;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.MappedJdbcTypes;
import org.apache.ibatis.type.MappedTypes;
import java.sql.*;
import java.util.Arrays;
import static org.springframework.util.ObjectUtils.toObjectArray;
@Slf4j
@MappedTypes({PGvector.class})
public class PGvectorTypeHandler extends BaseTypeHandler<PGvector> {
@Override
public void setNonNullParameter(PreparedStatement ps, int i, PGvector parameter, JdbcType jdbcType) throws SQLException {
log.info("Getting PGvector result by column name: {}", parameter.toString());
Connection conn = ps.getConnection();
Float[] boxedArray = FloatUtil.toObjectArray(parameter.toArray());
Array sqlArray = conn.createArrayOf("float", boxedArray);
ps.setArray(i, sqlArray);
}
@Override
public PGvector getNullableResult(ResultSet rs, String columnName) throws SQLException {
log.info("Getting PGvector result by column name: {}", columnName);
Array array = rs.getArray(columnName);
if (array != null) {
Float[] javaArray = (Float[])array.getArray();
return new PGvector(FloatUtil.toPrimitiveArray(javaArray));
}
return null;
}
@Override
public PGvector getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
log.info("Getting PGvector result by column index: {}", columnIndex);
return (PGvector) rs.getObject(columnIndex);
}
@Override
public PGvector getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
log.info("Getting PGvector result by column index: {}", columnIndex);
return (PGvector) cs.getObject(columnIndex);
}
}
MyBatisConfig.java对应修改
@Bean
public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception {
log.info("Registering PGvectorTypeHandler in sqlSessionFactory");
MybatisSqlSessionFactoryBean sqlSessionFactoryBean = new MybatisSqlSessionFactoryBean();
sqlSessionFactoryBean.setDataSource(dataSource);
// 注册 MyBatis Plus 拦截器
sqlSessionFactoryBean.setPlugins(mybatisPlusInterceptor());
// Set location of Mapper XML files
sqlSessionFactoryBean.setMapperLocations(
new PathMatchingResourcePatternResolver().getResources("classpath:/mapper/*.xml")
);
// 创建 MybatisConfiguration 实例并配置
MybatisConfiguration configuration = new MybatisConfiguration();
configuration.addMappers("com.md.gpt.mapper");
// configuration.addMapperLocation("classpath:mapper/*.xml");
// 例如,注册自定义类型处理器
configuration.getTypeHandlerRegistry().register(PGvectorTypeHandler.class);
sqlSessionFactoryBean.setConfiguration(configuration);
return sqlSessionFactoryBean.getObject();
}
@Bean
public TypeHandlerRegistry typeHandlerRegistry() {
log.info("Registering PGvectorTypeHandler in TypeHandlerRegistry");
TypeHandlerRegistry registry = new TypeHandlerRegistry();
// 注册自定义类型处理器
registry.register(PGvectorTypeHandler.class);
return registry;
}
PGvectorConfiguration.java
import com.md.gpt.mybatis.PGvectorTypeHandler;
import com.md.gpt.vector.PGvector;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.CommandLineRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
@Configuration
@Slf4j
public class PGvectorConfiguration {
@Bean
CommandLineRunner registerPGvectorType(DataSource dataSource) {
return args -> {
try (Connection conn = dataSource.getConnection()) {
log.info("Registering PGvector type with the database");
PGvector.addVectorType(conn); // This registers the 'vector' type
} catch (SQLException e) {
throw new RuntimeException("Failed to register PGvector type with the database", e);
}
};
}
}
FaguikuMapper.java
List<Faguiku> findNewClosestEmbeddings(@Param("embedding") PGvector embedding, String yearInfo,
String title, String keywords, String region, String docNumber,
String startDate, String endDate, String status, String fileType,
String docDanwei, String fileTaxType,int pageSize, int offset);
FaguikuMapper.xml
<select id="findNewClosestEmbeddings" resultMap="BaseResultMap">
SELECT id, title, doc_number, date_written, status, link, attachment_link, file_type, doc_danwei, file_tax_type, status_sort_order
FROM faguiku
<where>
<if test="yearInfo != null and yearInfo.trim() != ''">
title LIKE CONCAT('%', #{yearInfo}, '%') OR doc_number LIKE CONCAT('%', #{yearInfo}, '%')
</if>
<if test="title != null and title.trim() != ''">
AND to_tsvector('zh_cn', title) @@ plainto_tsquery('zh_cn', #{title})
</if>
<if test="keywords != null and keywords.trim() != ''">
AND to_tsvector('zh_cn', content2) @@ plainto_tsquery('zh_cn', #{keywords})
</if>
<if test="region != null and region.trim() != ''">
AND REPLACE(REPLACE(source, '省', ''), '市', '') = REPLACE(REPLACE(#{region}, '省', ''), '市', '')
</if>
<if test="docNumber != null and docNumber.trim() != ''">
AND doc_number LIKE CONCAT('%', #{docNumber}, '%')
</if>
<if test="status != null and status.trim() != ''">
AND status = #{status}
</if>
<if test="docDanwei != null and docDanwei.trim() != ''">
AND doc_danwei LIKE CONCAT('%', #{docDanwei}, '%')
</if>
<if test="fileTaxType != null and fileTaxType.trim() != ''">
AND file_tax_type = #{fileTaxType}
</if>
<if test="startDate != null and startDate.trim() != ''">
AND date_written >= #{startDate}
</if>
<if test="endDate != null and endDate.trim() != ''">
AND date_written <= #{endDate}
</if>
</where>
ORDER BY embeddings::vector <![CDATA[ <=> ]]> #{embedding}::vector, status_sort_order, date_written desc
LIMIT #{pageSize} OFFSET #{offset}
</select>
embeddings 通过Embedding API去获取,可以通过各个大模型的Embedding API,也可以通过部署Embeddings API, 根据huggingface中的排名来选择中文支持较好的Embedding模型
总结
这种方案,较好的考虑到了查询的方便性,对需要管理的文档,统一在数据库中管理;否则如果有成千上万篇文档,而不能有效的通过系统管理起来,那几乎很难维护了。通过Hybrid Search,结合传统数据库对于一些字段的完全匹配,结合全文搜索和向量搜索,得出来的结果,可以根据查询结果来调整,把符合条件的都过滤出来给用户做下一步分析使用。
有几个待进一步探讨探讨的问题:
1.文本分段,文章内容来源于很多地方,通过按照chunksize overlap来分段,然后再embedding, 有些自然段落被划分错了。通过调用大模型来划分,Token消耗比较多,时间比较久。目前使用的库里面,用到了3000万Token。而且有些划分还是不太合理,需要人工介入才能划分的合理。
2. Rerank模型,因为排序在sql里面做了,所以没有用rerank模型,准备按照rerank模型的算法看看能不能本地部署或者实现,来看看rerank能提高多少
3. 查询的文档比较多,最后给大模型来整理的话,一个是传入的长度有限制,不能把所有结果都传过去;二是太长的查询,消耗的也多。
4. 其他RAG方面的最新论文探索,比如Graph RAG等,后续进一步研究
5. 其他问题,欢迎加微信交流 rogerlzp