数据向量化保存并检索

3 篇文章 0 订阅
# generate_index.py

import re
import pickle
import faiss
import numpy as np
import pandas as pd
from transformers import BertTokenizer, BertModel
import torch
import mysql.connector
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# 1. 连接数据库并获取数据
def connect_mysql_getdata():
    config = {
        'user': 'root',  # 数据库用户名
        'password': '123456',  # 数据库密码
        'host': '127.0.0.1',  # 数据库主机地址域名
        'port': '3306',  # 数据库端口号
        'database': 'vector'  # 数据库名称
    }

    try:
        connection = mysql.connector.connect(**config)
        if connection.is_connected():
            print("成功连接到MySQL")
            cursor = connection.cursor()

            query = """ SELECT ques_id, detail       
                        FROM question_classify_total qct
                        WHERE detail IS NOT NULL
                        GROUP BY ques_id, detail 
                    """

            cursor.execute(query)
            result = cursor.fetchall()

            result = pd.DataFrame(result, columns=['beta_no', 'description'])
            return result

    except mysql.connector.Error as e:
        print(f"数据库连接或查询失败: {e}")
        return None

    finally:
        if 'connection' in locals() and connection.is_connected():
            cursor.close()
            connection.close()
            print("MySQL 连接已关闭")


# 2. 初始化BERT模型和分词器
bertbasepath = 'D:\\Project_Git\\bert-base-chinese' if os.name == 'nt' else '/data/bert-base-chinese/'
tokenizer = BertTokenizer.from_pretrained(bertbasepath)
model = BertModel.from_pretrained(bertbasepath)
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# 3. 定义向量化函数
def vectorize_texts(texts, batch_size=32):
    vectors = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            encoded_input = tokenizer(batch_texts, padding=True, truncation=True, max_length=128, return_tensors='pt')
            input_ids = encoded_input['input_ids'].to(device)
            attention_mask = encoded_input['attention_mask'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            vectors.append(cls_embeddings)
    return np.vstack(vectors)

# 4. 数据预处理函数
def preprocess_text(text):
    text = re.sub(r'[^A-Za-z0-9\u4e00-\u9fa5 ]+', '', text)
    return text.lower()

# 5. 生成向量和索引
def generate_faiss_index():
    df = connect_mysql_getdata()

    if df is not None and not df.empty:
        df['description'] = df['description'].astype(str).apply(preprocess_text)

        descriptions = df['description'].tolist()
        vectors = vectorize_texts(descriptions, batch_size=32)

        faiss.normalize_L2(vectors)
        index = faiss.IndexFlatIP(768)
        index.add(vectors)

        faiss.write_index(index, 'description_vectors.faiss')
        with open('id_mapping.pkl', 'wb') as f:
            pickle.dump(df['beta_no'].tolist(), f)

        print("向量索引和映射已保存。")
    else:
        print("没有数据可供索引。")


if __name__ == "__main__":
    generate_faiss_index()
# search_index.py

import re
import pickle
import faiss
import numpy as np
import torch
from transformers import BertTokenizer, BertModel
import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# 1. 初始化BERT模型和分词器
bertbasepath = 'D:\\Project_Git\\bert-base-chinese' if os.name == 'nt' else '/data/bert-base-chinese/'
tokenizer = BertTokenizer.from_pretrained(bertbasepath)
model = BertModel.from_pretrained(bertbasepath)
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)


# 2. 定义向量化函数
def vectorize_texts(texts, batch_size=32):
    vectors = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            encoded_input = tokenizer(batch_texts, padding=True, truncation=True, max_length=128, return_tensors='pt')
            input_ids = encoded_input['input_ids'].to(device)
            attention_mask = encoded_input['attention_mask'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            vectors.append(cls_embeddings)
    return np.vstack(vectors)


# 3. 数据预处理函数
def preprocess_text(text):
    text = re.sub(r'[^A-Za-z0-9\u4e00-\u9fa5 ]+', '', text)
    return text.lower()


# 4. 加载索引和映射
def load_index_and_mapping():
    index = faiss.read_index('description_vectors.faiss')
    with open('id_mapping.pkl', 'rb') as f:
        id_mapping = pickle.load(f)
    return index, id_mapping


# 5. 查找相似文本
def find_top_k_similar(query, index, id_mapping, top_k=3):
    query = preprocess_text(query)
    vectors = vectorize_texts([query], batch_size=1)
    faiss.normalize_L2(vectors)
    D, I = index.search(vectors, top_k)

    results = []
    for idx, score in zip(I[0], D[0]):
        if idx < len(id_mapping):
            beta_no = id_mapping[idx]
            results.append({'beta_no': beta_no, 'score': score})
    return results


# 6. 查询主程序
if __name__ == "__main__":
    index, id_mapping = load_index_and_mapping()

    query = "光合作用的原理是什么"
    top_k = 3
    similar_items = find_top_k_similar(query, index, id_mapping, top_k)

    if similar_items:
        print(f"Top {top_k} 相似结果:")
        for item in similar_items:
            print(f"beta_no: {item['beta_no']}, 相似度: {item['score']:.4f}")
    else:
        print("未找到相似结果。")

# update_index.py

import re
import pickle
import faiss
import numpy as np
import pandas as pd
from transformers import BertTokenizer, BertModel
import torch
import mysql.connector
import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


# 1. 连接数据库并获取数据
def connect_mysql_getdata():
    config = {
        'user': 'root',  # 数据库用户名
        'password': '123456',  # 数据库密码
        'host': '127.0.0.1',  # 数据库主机地址域名
        'port': '3306',  # 数据库端口号
        'database': 'vector'  # 数据库名称
    }

    try:
        connection = mysql.connector.connect(**config)
        if connection.is_connected():
            print("成功连接到MySQL")
            cursor = connection.cursor()

            query = """ SELECT ques_id, detail       
                        FROM question_classify_total qct
                        WHERE detail IS NOT NULL
                        GROUP BY ques_id, detail 
                    """

            cursor.execute(query)
            result = cursor.fetchall()

            result = pd.DataFrame(result, columns=['beta_no', 'description'])
            return result

    except mysql.connector.Error as e:
        print(f"数据库连接或查询失败: {e}")
        return None

    finally:
        if 'connection' in locals() and connection.is_connected():
            cursor.close()
            connection.close()
            print("MySQL 连接已关闭")


# 2. 初始化BERT模型和分词器
bertbasepath = 'D:\\Project_Git\\bert-base-chinese' if os.name == 'nt' else '/data/bert-base-chinese/'
tokenizer = BertTokenizer.from_pretrained(bertbasepath)
model = BertModel.from_pretrained(bertbasepath)
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)


# 3. 定义向量化函数
def vectorize_texts(texts, batch_size=32):
    vectors = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            encoded_input = tokenizer(batch_texts, padding=True, truncation=True, max_length=128, return_tensors='pt')
            input_ids = encoded_input['input_ids'].to(device)
            attention_mask = encoded_input['attention_mask'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            vectors.append(cls_embeddings)
    return np.vstack(vectors)


# 4. 数据预处理函数
def preprocess_text(text):
    text = re.sub(r'[^A-Za-z0-9\u4e00-\u9fa5 ]+', '', text)
    return text.lower()


# 5. 加载索引和映射文件
def load_index_and_mapping(index_file='description_vectors.faiss', mapping_file='id_mapping.pkl'):
    if os.path.exists(index_file) and os.path.exists(mapping_file):
        index = faiss.read_index(index_file)
        with open(mapping_file, 'rb') as f:
            id_mapping = pickle.load(f)
        return index, id_mapping
    else:
        print("索引文件或映射文件不存在,无法加载。")
        return None, None


# 6. 增量修改索引
def incremental_update():
    # 加载现有的索引和映射文件
    index, id_mapping = load_index_and_mapping()

    if index is None or id_mapping is None:
        print("索引或映射加载失败,无法进行增量更新。")
        return

    # 获取新数据
    new_data = connect_mysql_getdata()
    if new_data is None or new_data.empty:
        print("没有新数据可供更新。")
        return

    # 查找是否有增量数据
    new_data = new_data[~new_data['beta_no'].isin(id_mapping)]
    if new_data.empty:
        print("无增量数据,索引未更新。")
        return

    # 数据预处理
    new_data['description'] = new_data['description'].astype(str).apply(preprocess_text)

    # 向量化新数据
    new_descriptions = new_data['description'].tolist()
    new_vectors = vectorize_texts(new_descriptions, batch_size=32)

    # 归一化向量
    faiss.normalize_L2(new_vectors)

    # 添加新向量到索引
    index.add(new_vectors)

    # 更新id映射
    new_ids = new_data['beta_no'].tolist()
    id_mapping.extend(new_ids)

    # 保存更新后的索引和映射
    faiss.write_index(index, 'description_vectors.faiss')
    with open('id_mapping.pkl', 'wb') as f:
        pickle.dump(id_mapping, f)

    print(f"索引已增量更新,新增 {len(new_ids)} 条记录。")


# 7. 全量修改索引
def full_update():
    # 清空旧的索引和映射文件
    if os.path.exists('description_vectors.faiss'):
        os.remove('description_vectors.faiss')
    if os.path.exists('id_mapping.pkl'):
        os.remove('id_mapping.pkl')

    print("已删除旧的索引文件,准备全量更新。")

    # 获取全量数据
    full_data = connect_mysql_getdata()

    if full_data is None or full_data.empty:
        print("没有数据可供全量更新。")
        return

    # 数据预处理
    full_data['description'] = full_data['description'].astype(str).apply(preprocess_text)

    # 向量化全量数据
    descriptions = full_data['description'].tolist()
    vectors = vectorize_texts(descriptions, batch_size=32)

    # 归一化向量
    faiss.normalize_L2(vectors)

    # 创建新的索引
    index = faiss.IndexFlatIP(768)
    index.add(vectors)

    # 保存全量更新的索引和映射
    faiss.write_index(index, 'description_vectors.faiss')
    with open('id_mapping.pkl', 'wb') as f:
        pickle.dump(full_data['beta_no'].tolist(), f)

    print(f"索引已全量更新,共 {len(full_data)} 条记录。")


# 8. 更新主程序
if __name__ == "__main__":
    print("选择操作模式:")
    print("1. 增量更新索引")
    print("2. 全量更新索引")

    choice = input("请输入操作选项(1或2):")

    if choice == "1":
        incremental_update()
    elif choice == "2":
        full_update()
    else:
        print("无效的输入,请选择1或2。")

CREATE TABLE `question_classify_total`  (
  `ques_id` int NOT NULL,
  `detail` varchar(255) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NULL DEFAULT NULL,
  PRIMARY KEY (`ques_id`) USING BTREE
) ENGINE = InnoDB CHARACTER SET = utf8mb3 COLLATE = utf8mb3_general_ci ROW_FORMAT = Dynamic;
INSERT INTO `question_classify_total` VALUES (1, '法国的首都是哪里?');
INSERT INTO `question_classify_total` VALUES (2, '光合作用的原理是什么?');
INSERT INTO `question_classify_total` VALUES (3, '请解释相对论。');
INSERT INTO `question_classify_total` VALUES (4, '全球变暖有什么影响?');
INSERT INTO `question_classify_total` VALUES (5, '水的净化过程是怎样的?');
INSERT INTO `question_classify_total` VALUES (6, '均衡饮食的好处有哪些?');
INSERT INTO `question_classify_total` VALUES (7, '疫苗是如何预防疾病的?');
INSERT INTO `question_classify_total` VALUES (8, '社交媒体对社会有什么影响?');
INSERT INTO `question_classify_total` VALUES (9, '什么是人工智能?');
INSERT INTO `question_classify_total` VALUES (10, '可再生能源和不可再生能源的区别是什么?');
INSERT INTO `question_classify_total` VALUES (11, '法国最大的城市是哪座?');
INSERT INTO `question_classify_total` VALUES (12, '请描述光合作用的过程。');
INSERT INTO `question_classify_total` VALUES (13, '什么是爱因斯坦的相对论?');
INSERT INTO `question_classify_total` VALUES (14, '气候变化对地球的影响是什么?');
INSERT INTO `question_classify_total` VALUES (15, '水净化系统是如何工作的?');
INSERT INTO `question_classify_total` VALUES (16, '为什么均衡饮食很重要?');
INSERT INTO `question_classify_total` VALUES (17, '疫苗如何帮助免疫系统?');
INSERT INTO `question_classify_total` VALUES (18, '社交媒体如何影响社会?');
INSERT INTO `question_classify_total` VALUES (19, '什么是机器学习,它与人工智能有何关系?');
INSERT INTO `question_classify_total` VALUES (20, '可再生能源与不可再生能源如何对比?');
INSERT INTO `question_classify_total` VALUES (21, '计算机如何进行加密和解密?');
INSERT INTO `question_classify_total` VALUES (22, '区块链技术的工作原理是什么?');
INSERT INTO `question_classify_total` VALUES (23, '为什么太阳能是一种可持续能源?');
INSERT INTO `question_classify_total` VALUES (24, '介绍牛顿的三大运动定律。');
INSERT INTO `question_classify_total` VALUES (25, '如何解释宇宙大爆炸理论?');
INSERT INTO `question_classify_total` VALUES (26, '为什么地震会发生?');
INSERT INTO `question_classify_total` VALUES (27, '如何处理塑料污染问题?');
INSERT INTO `question_classify_total` VALUES (28, '人工智能如何改变医疗行业?');
INSERT INTO `question_classify_total` VALUES (29, '电动汽车的优势和劣势是什么?');
INSERT INTO `question_classify_total` VALUES (30, '光学纤维是如何传输数据的?');
INSERT INTO `question_classify_total` VALUES (31, '国际空间站是如何运行的?');
INSERT INTO `question_classify_total` VALUES (32, '如何提高网络安全?');
INSERT INTO `question_classify_total` VALUES (33, '什么是深度学习?');
INSERT INTO `question_classify_total` VALUES (34, '纳米技术如何应用于医学?');
INSERT INTO `question_classify_total` VALUES (35, '气候变化的根本原因是什么?');
INSERT INTO `question_classify_total` VALUES (36, '如何减少温室气体排放?');
INSERT INTO `question_classify_total` VALUES (37, '机器人技术如何改变制造业?');
INSERT INTO `question_classify_total` VALUES (38, '什么是量子计算?');
INSERT INTO `question_classify_total` VALUES (39, '如何储存太阳能?');
INSERT INTO `question_classify_total` VALUES (40, '物联网如何改变城市生活?');
INSERT INTO `question_classify_total` VALUES (41, '法国最大的城市是哪座?');
INSERT INTO `question_classify_total` VALUES (42, '请描述光合作用的过程。');
INSERT INTO `question_classify_total` VALUES (43, '什么是爱因斯坦的相对论?');
INSERT INTO `question_classify_total` VALUES (44, '气候变化对地球的影响是什么?');
INSERT INTO `question_classify_total` VALUES (45, '水净化系统是如何工作的?');
INSERT INTO `question_classify_total` VALUES (46, '为什么均衡饮食很重要?');
INSERT INTO `question_classify_total` VALUES (47, '疫苗如何帮助免疫系统?');
INSERT INTO `question_classify_total` VALUES (48, '社交媒体如何影响社会?');
INSERT INTO `question_classify_total` VALUES (49, '什么是机器学习,它与人工智能有何关系?');
INSERT INTO `question_classify_total` VALUES (50, '可再生能源与不可再生能源如何对比?');
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值