Springboot RAG 一站式混合搜索方案

RAG系列文章目录

第一章 Springboot RAG 一站式混合搜索方案



前言

最近在做一个政策类查询的RAG方案,做成一站式可以快速使用的方案。
数据库是PG, PGVector作为向量数据库,采用Hybrid Search方法来同时匹配向量和其他字段。
项目采用Springboot 作为后端;大模型相关的,使用到的API有Chatgpt, moonshot, qwen,讯飞星火等不同厂家的方案。
该方案从产品方面来考虑,可扩展性,可便利性等没有太多考虑;从单个项目来说,算是一个可用的方案。
Spring AI 支持所有主要的模型提供商,如 OpenAI、Microsoft、Amazon、Google 和 Huggingface;国内的大模型还没有支持,国内大模型的API的返回,有几个是兼容OpenAI的,另外一些是不兼容的,需要做不少工作来完全兼容。这是后面可以优化的方向,做成一个统一的接口,便于系统维护和更多人的上手使用。

一、PG

系统的数据库采用PG, 文本也是放在一个text字段中,用PG自带的全文检索,同时把向量匹配也放到一起过滤,所以向量数据库采用PGVector。

一、PGVector

PGVector的安装有很多写的详细的过程,这里略过。

二、PG全文搜索

PostgreSQL 全文检索:PostgreSQL 自带的全文检索功能可以使用 tsvector 和 tsquery 数据类型,通过分词和倒排索引来实现语义搜索的基本需求。
Zhparser 分词插件:对于中文文本,可以使用 PostgreSQL 的 Zhparser 插件进行中文分词,结合全文检索功能实现语义搜索。

二、SpringBoot 结合PGVector

PGVector.java

系统里面使用SpringBoot, MyBatisPlus的方式来,目前MyBatisPlus本身并不支持PGVector。需要增加一个PGVector的方式,
PGVector有项目https://github.com/pgvector/pgvector-java.git, 实现过程中参考该项目来实现PGVector;(也可以参考SpringAI项目对PGVector支持)

package com.md.gpt.vector;

import org.postgresql.PGConnection;
import org.postgresql.util.ByteConverter;
import org.postgresql.util.PGBinaryObject;
import org.postgresql.util.PGobject;

import java.io.Serializable;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

public class PGvector extends PGobject implements PGBinaryObject, Serializable, Cloneable {
    private float[] vec;

    /**
     * Constructor
     */
    public PGvector() {
        type = "vector";
    }

    /**
     * Constructor
     *
     * @param v float array
     */
    public PGvector(float[] v) {
        this();
        vec = v;
    }

    /**
     * Constructor
     *
     * @param <T> number
     * @param v list of numbers
     */
    public <T extends Number> PGvector(List<T> v) {
        this();
        if (Objects.isNull(v)) {
            vec = null;
        } else {
            vec = new float[v.size()];
            int i = 0;
            for (T f : v) {
                vec[i++] = f.floatValue();
            }
        }
    }

    /**
     * Constructor
     *
     * @param s text representation of a vector
     * @throws SQLException exception
     */
    public PGvector(String s) throws SQLException {
        this();
        setValue(s);
    }

    /**
     * Sets the value from a text representation of a vector
     */
    @Override
    public void setValue(String s) throws SQLException {
        if (s == null) {
            vec = null;
        } else {
            String[] sp = s.substring(1, s.length() - 1).split(",");
            vec = new float[sp.length];
            for (int i = 0; i < sp.length; i++) {
                vec[i] = Float.parseFloat(sp[i]);
            }
        }
    }

    /**
     * Returns the text representation of a vector
     */
    @Override
    public String getValue() {
        if (vec == null) {
            return null;
        } else {
            return Arrays.toString(vec).replace(" ", "");
        }
    }

    /**
     * Returns the number of bytes for the binary representation
     */
    @Override
    public int lengthInBytes() {
        return vec == null ? 0 : 4 + vec.length * 4;
    }

    /**
     * Sets the value from a binary representation of a vector
     */
    @Override
    public void setByteValue(byte[] value, int offset) throws SQLException {
        int dim = ByteConverter.int2(value, offset);

        int unused = ByteConverter.int2(value, offset + 2);
        if (unused != 0) {
            throw new SQLException("expected unused to be 0");
        }

        vec = new float[dim];
        for (int i = 0; i < dim; i++) {
            vec[i] = ByteConverter.float4(value, offset + 4 + i * 4);
        }
    }

    /**
     * Writes the binary representation of a vector
     */
    @Override
    public void toBytes(byte[] bytes, int offset) {
        if (vec == null) {
            return;
        }

        // server will error on overflow due to unconsumed buffer
        // could set to Short.MAX_VALUE for friendlier error message
        ByteConverter.int2(bytes, offset, vec.length);
        ByteConverter.int2(bytes, offset + 2, 0);
        for (int i = 0; i < vec.length; i++) {
            ByteConverter.float4(bytes, offset + 4 + i * 4, vec[i]);
        }
    }

    /**
     * Returns an array
     *
     * @return an array
     */
    public float[] toArray() {
        return vec;
    }

    /**
     * Registers the vector type
     *
     * @param conn connection
     * @throws SQLException exception
     */
    public static void addVectorType(Connection conn) throws SQLException {
        conn.unwrap(PGConnection.class).addDataType("vector", PGvector.class);
    }
}

PGvectorTypeHandler.java

PGvectorTypeHandler 类主要如下

import com.md.gpt.util.FloatUtil;
import com.md.gpt.vector.PGvector;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.MappedJdbcTypes;
import org.apache.ibatis.type.MappedTypes;

import java.sql.*;
import java.util.Arrays;

import static org.springframework.util.ObjectUtils.toObjectArray;

@Slf4j
@MappedTypes({PGvector.class})
public class PGvectorTypeHandler extends BaseTypeHandler<PGvector> {

    @Override
    public void setNonNullParameter(PreparedStatement ps, int i, PGvector parameter, JdbcType jdbcType) throws SQLException {
        log.info("Getting PGvector result by column name: {}", parameter.toString());
        Connection conn = ps.getConnection();
        Float[] boxedArray = FloatUtil.toObjectArray(parameter.toArray());
        Array sqlArray = conn.createArrayOf("float", boxedArray);
        ps.setArray(i, sqlArray);
    }

    @Override
    public PGvector getNullableResult(ResultSet rs, String columnName) throws SQLException {
        log.info("Getting PGvector result by column name: {}", columnName);
        Array array = rs.getArray(columnName);
        if (array != null) {
            Float[] javaArray = (Float[])array.getArray();
            return new PGvector(FloatUtil.toPrimitiveArray(javaArray));
        }
        return null;
    }

    @Override
    public PGvector getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
        log.info("Getting PGvector result by column index: {}", columnIndex);
        return (PGvector) rs.getObject(columnIndex);
    }

    @Override
    public PGvector getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
        log.info("Getting PGvector result by column index: {}", columnIndex);
        return (PGvector) cs.getObject(columnIndex);
    }
}

MyBatisConfig.java对应修改

 @Bean
    public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception {
        log.info("Registering PGvectorTypeHandler in sqlSessionFactory");

        MybatisSqlSessionFactoryBean sqlSessionFactoryBean = new MybatisSqlSessionFactoryBean();
        sqlSessionFactoryBean.setDataSource(dataSource);

        // 注册 MyBatis Plus 拦截器
        sqlSessionFactoryBean.setPlugins(mybatisPlusInterceptor());

        // Set location of Mapper XML files
        sqlSessionFactoryBean.setMapperLocations(
                new PathMatchingResourcePatternResolver().getResources("classpath:/mapper/*.xml")
        );


        // 创建 MybatisConfiguration 实例并配置
        MybatisConfiguration configuration = new MybatisConfiguration();
        configuration.addMappers("com.md.gpt.mapper");
//        configuration.addMapperLocation("classpath:mapper/*.xml");
        // 例如,注册自定义类型处理器
        configuration.getTypeHandlerRegistry().register(PGvectorTypeHandler.class);

        sqlSessionFactoryBean.setConfiguration(configuration);
        return sqlSessionFactoryBean.getObject();
    }

    @Bean
    public TypeHandlerRegistry typeHandlerRegistry() {
        log.info("Registering PGvectorTypeHandler in TypeHandlerRegistry");

        TypeHandlerRegistry registry = new TypeHandlerRegistry();
        // 注册自定义类型处理器
        registry.register(PGvectorTypeHandler.class);
        return registry;
    }

PGvectorConfiguration.java

import com.md.gpt.mybatis.PGvectorTypeHandler;
import com.md.gpt.vector.PGvector;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.CommandLineRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;

@Configuration
@Slf4j
public class PGvectorConfiguration {

    @Bean
    CommandLineRunner registerPGvectorType(DataSource dataSource) {
        return args -> {
            try (Connection conn = dataSource.getConnection()) {
                log.info("Registering PGvector type with the database");
                PGvector.addVectorType(conn); // This registers the 'vector' type
            } catch (SQLException e) {
                throw new RuntimeException("Failed to register PGvector type with the database", e);
            }
        };
    }
}

FaguikuMapper.java

List<Faguiku> findNewClosestEmbeddings(@Param("embedding") PGvector embedding, String yearInfo,
                                       String title, String keywords, String region, String docNumber,
                                       String startDate, String endDate, String status, String fileType,
                                       String docDanwei, String fileTaxType,int pageSize, int offset);

FaguikuMapper.xml

<select id="findNewClosestEmbeddings" resultMap="BaseResultMap">
    SELECT  id, title, doc_number, date_written, status, link, attachment_link, file_type, doc_danwei, file_tax_type, status_sort_order
    FROM faguiku
    <where>
        <if test="yearInfo != null and yearInfo.trim() != ''">
            title LIKE CONCAT('%', #{yearInfo}, '%') OR doc_number LIKE CONCAT('%', #{yearInfo}, '%')
        </if>
        <if test="title != null and title.trim() != ''">
            AND to_tsvector('zh_cn', title) @@ plainto_tsquery('zh_cn', #{title})
        </if>
        <if test="keywords != null and keywords.trim() != ''">
            AND  to_tsvector('zh_cn', content2) @@ plainto_tsquery('zh_cn', #{keywords})
        </if>
        <if test="region != null and region.trim() != ''">
            AND  REPLACE(REPLACE(source, '省', ''), '市', '') = REPLACE(REPLACE(#{region}, '省', ''), '市', '')
        </if>
        <if test="docNumber != null and docNumber.trim() != ''">
            AND  doc_number LIKE CONCAT('%', #{docNumber}, '%')
        </if>
        <if test="status != null and status.trim() != ''">
            AND status = #{status}
        </if>
        <if test="docDanwei != null and docDanwei.trim() != ''">
            AND doc_danwei LIKE CONCAT('%', #{docDanwei}, '%')
        </if>
        <if test="fileTaxType != null and fileTaxType.trim() != ''">
            AND file_tax_type = #{fileTaxType}
        </if>
        <if test="startDate != null and startDate.trim() != ''">
            AND date_written &gt;= #{startDate}
        </if>
        <if test="endDate != null and endDate.trim() != ''">
            AND date_written &lt;= #{endDate}
        </if>
    </where>
    ORDER BY embeddings::vector    <![CDATA[ <=>  ]]>  #{embedding}::vector, status_sort_order, date_written desc
    LIMIT #{pageSize} OFFSET #{offset}
</select>

embeddings 通过Embedding API去获取,可以通过各个大模型的Embedding API,也可以通过部署Embeddings API, 根据huggingface中的排名来选择中文支持较好的Embedding模型


总结

这种方案,较好的考虑到了查询的方便性,对需要管理的文档,统一在数据库中管理;否则如果有成千上万篇文档,而不能有效的通过系统管理起来,那几乎很难维护了。通过Hybrid Search,结合传统数据库对于一些字段的完全匹配,结合全文搜索和向量搜索,得出来的结果,可以根据查询结果来调整,把符合条件的都过滤出来给用户做下一步分析使用。
有几个待进一步探讨探讨的问题:
1.文本分段,文章内容来源于很多地方,通过按照chunksize overlap来分段,然后再embedding, 有些自然段落被划分错了。通过调用大模型来划分,Token消耗比较多,时间比较久。目前使用的库里面,用到了3000万Token。而且有些划分还是不太合理,需要人工介入才能划分的合理。
2. Rerank模型,因为排序在sql里面做了,所以没有用rerank模型,准备按照rerank模型的算法看看能不能本地部署或者实现,来看看rerank能提高多少
3. 查询的文档比较多,最后给大模型来整理的话,一个是传入的长度有限制,不能把所有结果都传过去;二是太长的查询,消耗的也多。
4. 其他RAG方面的最新论文探索,比如Graph RAG等,后续进一步研究
5. 其他问题,欢迎加微信交流 rogerlzp

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值