import re
import pickle
import faiss
import numpy as np
import pandas as pd
from transformers import BertTokenizer, BertModel
import torch
import mysql.connector
import os
# NOTE(review): KMP_DUPLICATE_LIB_OK=TRUE is commonly set to stop the Intel OpenMP
# runtime from aborting when two copies of it are loaded into one process (torch
# and faiss may each bring one). It silences that check rather than removing the
# duplicate load, and here it runs after those imports — confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
def connect_mysql_getdata():
    """Fetch (question id, question text) pairs from MySQL as a DataFrame.

    Reads ``ques_id`` and non-NULL ``detail`` from ``question_classify_total``.
    Connection settings come from the ``MYSQL_USER``, ``MYSQL_PASSWORD``,
    ``MYSQL_HOST``, ``MYSQL_PORT`` and ``MYSQL_DATABASE`` environment
    variables, falling back to the local-development defaults.

    Returns:
        pandas.DataFrame with columns ``beta_no`` (from ``ques_id``) and
        ``description`` (from ``detail``), or ``None`` when the connection
        could not be established or the query failed.
    """
    # Credentials should not have to live in source control; env vars override
    # the development defaults without breaking existing setups.
    config = {
        'user': os.environ.get('MYSQL_USER', 'root'),
        'password': os.environ.get('MYSQL_PASSWORD', '123456'),
        'host': os.environ.get('MYSQL_HOST', '127.0.0.1'),
        'port': int(os.environ.get('MYSQL_PORT', '3306')),
        'database': os.environ.get('MYSQL_DATABASE', 'vector'),
    }
    query = """ SELECT ques_id, detail
    FROM question_classify_total qct
    WHERE detail IS NOT NULL
    GROUP BY ques_id, detail
    """
    # Pre-bind both handles so the cleanup below never touches an unbound name
    # when connect() or cursor() raises part-way through.
    connection = None
    cursor = None
    try:
        connection = mysql.connector.connect(**config)
        if not connection.is_connected():
            return None
        print("成功连接到MySQL")
        cursor = connection.cursor()
        cursor.execute(query)
        rows = cursor.fetchall()
        return pd.DataFrame(rows, columns=['beta_no', 'description'])
    except mysql.connector.Error as e:
        print(f"数据库连接或查询失败: {e}")
        return None
    finally:
        if connection is not None and connection.is_connected():
            if cursor is not None:
                cursor.close()
            connection.close()
            print("MySQL 连接已关闭")
# Location of the bert-base-chinese checkpoint: Windows dev checkout vs. the
# Linux server path.
if os.name == 'nt':
    bertbasepath = 'D:\\Project_Git\\bert-base-chinese'
else:
    bertbasepath = '/data/bert-base-chinese/'

# Shared tokenizer/encoder used by vectorize_texts(); placed on the GPU when
# one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = BertTokenizer.from_pretrained(bertbasepath)
model = BertModel.from_pretrained(bertbasepath)
model.eval()  # inference only (e.g. disables dropout)
model.to(device)
def vectorize_texts(texts, batch_size=32, max_length=128):
    """Embed texts with the module-level BERT model, one [CLS] vector per text.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        texts: Strings to encode (any iterable; it is materialized to a list).
        batch_size: Number of texts per forward pass; must be >= 1.
        max_length: Maximum tokens per text; longer texts are truncated.

    Returns:
        numpy.ndarray with one row per input text, in input order; each row is
        the final-layer hidden state at position 0 ([CLS]). An empty input
        yields an empty ``(0, hidden_size)`` float32 array instead of raising.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    texts = list(texts)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not texts:
        # np.vstack([]) raises, so return a correctly shaped empty matrix.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    vectors = []
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            encoded_input = tokenizer(batch_texts, padding=True, truncation=True,
                                      max_length=max_length, return_tensors='pt')
            input_ids = encoded_input['input_ids'].to(device)
            attention_mask = encoded_input['attention_mask'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            # Position 0 of every sequence holds the [CLS] token.
            cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            vectors.append(cls_embeddings)
    return np.vstack(vectors)
def preprocess_text(text):
    """Normalize a string before it is embedded.

    Every character that is not an ASCII letter, an ASCII digit, a plain space,
    or a CJK ideograph in U+4E00..U+9FA5 is removed (this includes punctuation,
    tabs and newlines); the remainder is lower-cased.
    """
    # Index building and querying must use exactly this normalization, or the
    # stored and query embeddings stop being comparable.
    disallowed = re.compile(r'[^A-Za-z0-9\u4e00-\u9fa5 ]+')
    cleaned = disallowed.sub('', text)
    return cleaned.lower()
def generate_faiss_index(index_file='description_vectors.faiss', mapping_file='id_mapping.pkl'):
    """Build a cosine-similarity FAISS index over all question descriptions.

    Fetches rows with connect_mysql_getdata(), preprocesses and embeds each
    description, L2-normalizes the vectors, and stores them in an
    inner-product index (inner product of unit vectors == cosine similarity).
    Position i of the index corresponds to element i of the pickled list of
    ``beta_no`` values written next to it.

    Args:
        index_file: Path the FAISS index is written to.
        mapping_file: Path the pickled position -> beta_no list is written to.
    """
    df = connect_mysql_getdata()
    if df is None or df.empty:
        print("没有数据可供索引。")
        return
    df['description'] = df['description'].astype(str).apply(preprocess_text)
    descriptions = df['description'].tolist()
    vectors = vectorize_texts(descriptions, batch_size=32)
    # faiss needs a C-contiguous float32 matrix; normalize_L2 works in place.
    vectors = np.ascontiguousarray(vectors, dtype=np.float32)
    faiss.normalize_L2(vectors)
    # Derive the dimension from the embeddings instead of hard-coding 768, so a
    # different encoder checkpoint cannot break index construction.
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    faiss.write_index(index, index_file)
    with open(mapping_file, 'wb') as f:
        pickle.dump(df['beta_no'].tolist(), f)
    print("向量索引和映射已保存。")


if __name__ == "__main__":
    generate_faiss_index()
import re
import pickle
import faiss
import numpy as np
import torch
from transformers import BertTokenizer, BertModel
import os
# NOTE(review): KMP_DUPLICATE_LIB_OK=TRUE is commonly set to stop the Intel OpenMP
# runtime from aborting when two copies are loaded into one process (torch and
# faiss may each bring one); it silences the check rather than fixing it — confirm.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Location of the bert-base-chinese checkpoint: Windows dev checkout vs. the
# Linux server path.
if os.name == 'nt':
    bertbasepath = 'D:\\Project_Git\\bert-base-chinese'
else:
    bertbasepath = '/data/bert-base-chinese/'

# Shared tokenizer/encoder used by vectorize_texts(); placed on the GPU when
# one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = BertTokenizer.from_pretrained(bertbasepath)
model = BertModel.from_pretrained(bertbasepath)
model.eval()  # inference only (e.g. disables dropout)
model.to(device)
def vectorize_texts(texts, batch_size=32, max_length=128):
    """Embed texts with the module-level BERT model, one [CLS] vector per text.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        texts: Strings to encode (any iterable; it is materialized to a list).
        batch_size: Number of texts per forward pass; must be >= 1.
        max_length: Maximum tokens per text; longer texts are truncated.

    Returns:
        numpy.ndarray with one row per input text, in input order; each row is
        the final-layer hidden state at position 0 ([CLS]). An empty input
        yields an empty ``(0, hidden_size)`` float32 array instead of raising.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    texts = list(texts)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not texts:
        # np.vstack([]) raises, so return a correctly shaped empty matrix.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    vectors = []
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            encoded_input = tokenizer(batch_texts, padding=True, truncation=True,
                                      max_length=max_length, return_tensors='pt')
            input_ids = encoded_input['input_ids'].to(device)
            attention_mask = encoded_input['attention_mask'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            # Position 0 of every sequence holds the [CLS] token.
            cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            vectors.append(cls_embeddings)
    return np.vstack(vectors)
def preprocess_text(text):
    """Normalize a string before it is embedded.

    Every character that is not an ASCII letter, an ASCII digit, a plain space,
    or a CJK ideograph in U+4E00..U+9FA5 is removed (this includes punctuation,
    tabs and newlines); the remainder is lower-cased.
    """
    # Querying must use exactly the normalization the index was built with, or
    # the query embedding is not comparable to the stored ones.
    disallowed = re.compile(r'[^A-Za-z0-9\u4e00-\u9fa5 ]+')
    cleaned = disallowed.sub('', text)
    return cleaned.lower()
def load_index_and_mapping(index_file='description_vectors.faiss', mapping_file='id_mapping.pkl'):
    """Load the FAISS index and its position -> beta_no mapping from disk.

    Args:
        index_file: Path of the FAISS index written by the index builder.
        mapping_file: Path of the pickled list whose i-th element is the
            ``beta_no`` stored at index position i.

    Returns:
        ``(index, id_mapping)``.

    Raises:
        ValueError: if the index and the mapping hold a different number of
            entries; positions would then resolve to the wrong ids.
        Errors from ``faiss.read_index`` and ``open`` (e.g. a missing file)
        propagate unchanged.
    """
    index = faiss.read_index(index_file)
    # SECURITY: pickle.load can execute arbitrary code; only load mapping files
    # this pipeline produced itself.
    with open(mapping_file, 'rb') as f:
        id_mapping = pickle.load(f)
    if index.ntotal != len(id_mapping):
        raise ValueError(
            f"index has {index.ntotal} vectors but mapping has {len(id_mapping)} ids; "
            "rebuild the index"
        )
    return index, id_mapping
def find_top_k_similar(query, index, id_mapping, top_k=3):
    """Return the stored questions most similar to ``query``.

    The query is normalized with preprocess_text(), embedded with
    vectorize_texts(), L2-normalized, and searched against ``index``.

    Args:
        query: Free-text question.
        index: FAISS index built from L2-normalized description vectors.
        id_mapping: List whose i-th element is the ``beta_no`` at position i.
        top_k: Maximum number of results; values <= 0 return an empty list.

    Returns:
        List of ``{'beta_no': ..., 'score': float}`` dicts in the order FAISS
        returned them. It may contain fewer than ``top_k`` items when the index
        holds fewer vectors.
    """
    if top_k <= 0:
        return []
    query = preprocess_text(query)
    vectors = vectorize_texts([query], batch_size=1)
    faiss.normalize_L2(vectors)
    scores, positions = index.search(vectors, top_k)
    results = []
    for idx, score in zip(positions[0], scores[0]):
        # FAISS pads missing neighbours with label -1; without the lower bound,
        # id_mapping[-1] would silently report the last stored id.
        if 0 <= idx < len(id_mapping):
            results.append({'beta_no': id_mapping[int(idx)], 'score': float(score)})
    return results
if __name__ == "__main__":
    # Demo: embed one sample question and print its nearest stored questions.
    index, id_mapping = load_index_and_mapping()
    top_k = 3
    sample_query = "光合作用的原理是什么"
    matches = find_top_k_similar(sample_query, index, id_mapping, top_k)
    if not matches:
        print("未找到相似结果。")
    else:
        print(f"Top {top_k} 相似结果:")
        for hit in matches:
            print(f"beta_no: {hit['beta_no']}, 相似度: {hit['score']:.4f}")
import re
import pickle
import faiss
import numpy as np
import pandas as pd
from transformers import BertTokenizer, BertModel
import torch
import mysql.connector
import os
# NOTE(review): KMP_DUPLICATE_LIB_OK=TRUE is commonly set to stop the Intel OpenMP
# runtime from aborting when two copies of it are loaded into one process (torch
# and faiss may each bring one). It silences that check rather than removing the
# duplicate load, and here it runs after those imports — confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
def connect_mysql_getdata():
    """Fetch (question id, question text) pairs from MySQL as a DataFrame.

    Reads ``ques_id`` and non-NULL ``detail`` from ``question_classify_total``.
    Connection settings come from the ``MYSQL_USER``, ``MYSQL_PASSWORD``,
    ``MYSQL_HOST``, ``MYSQL_PORT`` and ``MYSQL_DATABASE`` environment
    variables, falling back to the local-development defaults.

    Returns:
        pandas.DataFrame with columns ``beta_no`` (from ``ques_id``) and
        ``description`` (from ``detail``), or ``None`` when the connection
        could not be established or the query failed.
    """
    # Credentials should not have to live in source control; env vars override
    # the development defaults without breaking existing setups.
    config = {
        'user': os.environ.get('MYSQL_USER', 'root'),
        'password': os.environ.get('MYSQL_PASSWORD', '123456'),
        'host': os.environ.get('MYSQL_HOST', '127.0.0.1'),
        'port': int(os.environ.get('MYSQL_PORT', '3306')),
        'database': os.environ.get('MYSQL_DATABASE', 'vector'),
    }
    query = """ SELECT ques_id, detail
    FROM question_classify_total qct
    WHERE detail IS NOT NULL
    GROUP BY ques_id, detail
    """
    # Pre-bind both handles so the cleanup below never touches an unbound name
    # when connect() or cursor() raises part-way through.
    connection = None
    cursor = None
    try:
        connection = mysql.connector.connect(**config)
        if not connection.is_connected():
            return None
        print("成功连接到MySQL")
        cursor = connection.cursor()
        cursor.execute(query)
        rows = cursor.fetchall()
        return pd.DataFrame(rows, columns=['beta_no', 'description'])
    except mysql.connector.Error as e:
        print(f"数据库连接或查询失败: {e}")
        return None
    finally:
        if connection is not None and connection.is_connected():
            if cursor is not None:
                cursor.close()
            connection.close()
            print("MySQL 连接已关闭")
# Location of the bert-base-chinese checkpoint: Windows dev checkout vs. the
# Linux server path.
if os.name == 'nt':
    bertbasepath = 'D:\\Project_Git\\bert-base-chinese'
else:
    bertbasepath = '/data/bert-base-chinese/'

# Shared tokenizer/encoder used by vectorize_texts(); placed on the GPU when
# one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = BertTokenizer.from_pretrained(bertbasepath)
model = BertModel.from_pretrained(bertbasepath)
model.eval()  # inference only (e.g. disables dropout)
model.to(device)
def vectorize_texts(texts, batch_size=32, max_length=128):
    """Embed texts with the module-level BERT model, one [CLS] vector per text.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        texts: Strings to encode (any iterable; it is materialized to a list).
        batch_size: Number of texts per forward pass; must be >= 1.
        max_length: Maximum tokens per text; longer texts are truncated.

    Returns:
        numpy.ndarray with one row per input text, in input order; each row is
        the final-layer hidden state at position 0 ([CLS]). An empty input
        yields an empty ``(0, hidden_size)`` float32 array instead of raising.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    texts = list(texts)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not texts:
        # np.vstack([]) raises, so return a correctly shaped empty matrix.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    vectors = []
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            encoded_input = tokenizer(batch_texts, padding=True, truncation=True,
                                      max_length=max_length, return_tensors='pt')
            input_ids = encoded_input['input_ids'].to(device)
            attention_mask = encoded_input['attention_mask'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            # Position 0 of every sequence holds the [CLS] token.
            cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            vectors.append(cls_embeddings)
    return np.vstack(vectors)
def preprocess_text(text):
    """Normalize a string before it is embedded.

    Every character that is not an ASCII letter, an ASCII digit, a plain space,
    or a CJK ideograph in U+4E00..U+9FA5 is removed (this includes punctuation,
    tabs and newlines); the remainder is lower-cased.
    """
    # Index updates and querying must use exactly this normalization, or the
    # stored and query embeddings stop being comparable.
    disallowed = re.compile(r'[^A-Za-z0-9\u4e00-\u9fa5 ]+')
    cleaned = disallowed.sub('', text)
    return cleaned.lower()
def load_index_and_mapping(index_file='description_vectors.faiss', mapping_file='id_mapping.pkl'):
    """Load the FAISS index and its position -> beta_no mapping, if usable.

    Args:
        index_file: Path of the FAISS index.
        mapping_file: Path of the pickled list whose i-th element is the
            ``beta_no`` stored at index position i.

    Returns:
        ``(index, id_mapping)`` on success. Returns ``(None, None)`` when either
        file is missing, or when the two hold a different number of entries.
        In the mismatch case, positions would resolve to the wrong ids and any
        incremental append would widen the misalignment.
    """
    if not (os.path.exists(index_file) and os.path.exists(mapping_file)):
        print("索引文件或映射文件不存在,无法加载。")
        return None, None
    index = faiss.read_index(index_file)
    # SECURITY: pickle.load can execute arbitrary code; only load mapping files
    # this pipeline produced itself.
    with open(mapping_file, 'rb') as f:
        id_mapping = pickle.load(f)
    if index.ntotal != len(id_mapping):
        print(f"索引({index.ntotal} 条)与映射({len(id_mapping)} 条)数量不一致,无法加载,请执行全量更新。")
        return None, None
    return index, id_mapping
def incremental_update():
    """Append vectors for questions whose ``beta_no`` is not yet indexed.

    Only rows with a previously unseen id are added. Rows whose id is already
    indexed are skipped even if their text changed, and rows deleted from the
    database stay in the index; run full_update() to reconcile those.

    Both files are first written to ``*.tmp`` paths and then swapped in with
    os.replace. A failure while writing therefore leaves the existing pair
    untouched; only the two renames themselves can leave them out of step.
    """
    index_file = 'description_vectors.faiss'
    mapping_file = 'id_mapping.pkl'
    index, id_mapping = load_index_and_mapping()
    if index is None or id_mapping is None:
        print("索引或映射加载失败,无法进行增量更新。")
        return
    new_data = connect_mysql_getdata()
    if new_data is None or new_data.empty:
        print("没有新数据可供更新。")
        return
    # set() gives O(1) membership; .copy() makes the filtered frame independent
    # so the column assignment below does not act on a view.
    known_ids = set(id_mapping)
    new_data = new_data[~new_data['beta_no'].isin(known_ids)].copy()
    if new_data.empty:
        print("无增量数据,索引未更新。")
        return
    new_data['description'] = new_data['description'].astype(str).apply(preprocess_text)
    new_descriptions = new_data['description'].tolist()
    new_vectors = vectorize_texts(new_descriptions, batch_size=32)
    # faiss needs a C-contiguous float32 matrix; normalize_L2 works in place.
    new_vectors = np.ascontiguousarray(new_vectors, dtype=np.float32)
    faiss.normalize_L2(new_vectors)
    index.add(new_vectors)
    new_ids = new_data['beta_no'].tolist()
    id_mapping.extend(new_ids)
    tmp_index = index_file + '.tmp'
    tmp_mapping = mapping_file + '.tmp'
    faiss.write_index(index, tmp_index)
    with open(tmp_mapping, 'wb') as f:
        pickle.dump(id_mapping, f)
    os.replace(tmp_index, index_file)
    os.replace(tmp_mapping, mapping_file)
    print(f"索引已增量更新,新增 {len(new_ids)} 条记录。")
def full_update():
    """Rebuild the index and mapping from every row in the database.

    The replacement is built completely before the existing files are touched.
    A database error or an empty result therefore leaves the previous index
    in place instead of deleting it. New files are written to ``*.tmp`` paths
    and swapped in with os.replace.
    """
    index_file = 'description_vectors.faiss'
    mapping_file = 'id_mapping.pkl'
    full_data = connect_mysql_getdata()
    if full_data is None or full_data.empty:
        print("没有数据可供全量更新。")
        return
    full_data['description'] = full_data['description'].astype(str).apply(preprocess_text)
    descriptions = full_data['description'].tolist()
    vectors = vectorize_texts(descriptions, batch_size=32)
    # faiss needs a C-contiguous float32 matrix; normalize_L2 works in place.
    vectors = np.ascontiguousarray(vectors, dtype=np.float32)
    faiss.normalize_L2(vectors)
    # Derive the dimension from the embeddings instead of hard-coding 768.
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    tmp_index = index_file + '.tmp'
    tmp_mapping = mapping_file + '.tmp'
    faiss.write_index(index, tmp_index)
    with open(tmp_mapping, 'wb') as f:
        pickle.dump(full_data['beta_no'].tolist(), f)
    # os.replace overwrites the old files, so no separate delete step is needed.
    os.replace(tmp_index, index_file)
    os.replace(tmp_mapping, mapping_file)
    print(f"索引已全量更新,共 {len(full_data)} 条记录。")
if __name__ == "__main__":
    # Interactive entry point: "1" runs an incremental update, "2" a full rebuild.
    for menu_line in ("选择操作模式:", "1. 增量更新索引", "2. 全量更新索引"):
        print(menu_line)
    actions = {"1": incremental_update, "2": full_update}
    choice = input("请输入操作选项(1或2):")
    action = actions.get(choice)
    if action is None:
        print("无效的输入,请选择1或2。")
    else:
        action()
-- Source table for the similarity index: the Python loader selects ques_id
-- (exposed as beta_no) and non-NULL detail (the text that gets embedded).
-- utf8mb4 replaces the deprecated utf8mb3, which cannot store characters
-- outside the Basic Multilingual Plane.
CREATE TABLE `question_classify_total` (
  `ques_id` int NOT NULL,
  `detail` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NULL DEFAULT NULL,
  PRIMARY KEY (`ques_id`) USING BTREE
) ENGINE = InnoDB CHARACTER SET = utf8mb4 COLLATE = utf8mb4_general_ci ROW_FORMAT = Dynamic;
-- Sample seed data: 50 questions. Rows 11-20 paraphrase rows 1-10, and rows
-- 41-50 repeat the text of rows 11-20 verbatim under new ids.
INSERT INTO `question_classify_total` VALUES (1, '法国的首都是哪里?');
INSERT INTO `question_classify_total` VALUES (2, '光合作用的原理是什么?');
INSERT INTO `question_classify_total` VALUES (3, '请解释相对论。');
INSERT INTO `question_classify_total` VALUES (4, '全球变暖有什么影响?');
INSERT INTO `question_classify_total` VALUES (5, '水的净化过程是怎样的?');
INSERT INTO `question_classify_total` VALUES (6, '均衡饮食的好处有哪些?');
INSERT INTO `question_classify_total` VALUES (7, '疫苗是如何预防疾病的?');
INSERT INTO `question_classify_total` VALUES (8, '社交媒体对社会有什么影响?');
INSERT INTO `question_classify_total` VALUES (9, '什么是人工智能?');
INSERT INTO `question_classify_total` VALUES (10, '可再生能源和不可再生能源的区别是什么?');
INSERT INTO `question_classify_total` VALUES (11, '法国最大的城市是哪座?');
INSERT INTO `question_classify_total` VALUES (12, '请描述光合作用的过程。');
INSERT INTO `question_classify_total` VALUES (13, '什么是爱因斯坦的相对论?');
INSERT INTO `question_classify_total` VALUES (14, '气候变化对地球的影响是什么?');
INSERT INTO `question_classify_total` VALUES (15, '水净化系统是如何工作的?');
INSERT INTO `question_classify_total` VALUES (16, '为什么均衡饮食很重要?');
INSERT INTO `question_classify_total` VALUES (17, '疫苗如何帮助免疫系统?');
INSERT INTO `question_classify_total` VALUES (18, '社交媒体如何影响社会?');
INSERT INTO `question_classify_total` VALUES (19, '什么是机器学习,它与人工智能有何关系?');
INSERT INTO `question_classify_total` VALUES (20, '可再生能源与不可再生能源如何对比?');
INSERT INTO `question_classify_total` VALUES (21, '计算机如何进行加密和解密?');
INSERT INTO `question_classify_total` VALUES (22, '区块链技术的工作原理是什么?');
INSERT INTO `question_classify_total` VALUES (23, '为什么太阳能是一种可持续能源?');
INSERT INTO `question_classify_total` VALUES (24, '介绍牛顿的三大运动定律。');
INSERT INTO `question_classify_total` VALUES (25, '如何解释宇宙大爆炸理论?');
INSERT INTO `question_classify_total` VALUES (26, '为什么地震会发生?');
INSERT INTO `question_classify_total` VALUES (27, '如何处理塑料污染问题?');
INSERT INTO `question_classify_total` VALUES (28, '人工智能如何改变医疗行业?');
INSERT INTO `question_classify_total` VALUES (29, '电动汽车的优势和劣势是什么?');
INSERT INTO `question_classify_total` VALUES (30, '光学纤维是如何传输数据的?');
INSERT INTO `question_classify_total` VALUES (31, '国际空间站是如何运行的?');
INSERT INTO `question_classify_total` VALUES (32, '如何提高网络安全?');
INSERT INTO `question_classify_total` VALUES (33, '什么是深度学习?');
INSERT INTO `question_classify_total` VALUES (34, '纳米技术如何应用于医学?');
INSERT INTO `question_classify_total` VALUES (35, '气候变化的根本原因是什么?');
INSERT INTO `question_classify_total` VALUES (36, '如何减少温室气体排放?');
INSERT INTO `question_classify_total` VALUES (37, '机器人技术如何改变制造业?');
INSERT INTO `question_classify_total` VALUES (38, '什么是量子计算?');
INSERT INTO `question_classify_total` VALUES (39, '如何储存太阳能?');
INSERT INTO `question_classify_total` VALUES (40, '物联网如何改变城市生活?');
-- Rows 41-50: exact text duplicates of rows 11-20, so identical descriptions
-- appear twice in the index.
-- NOTE(review): presumably intentional test data for duplicate matches; confirm.
INSERT INTO `question_classify_total` VALUES (41, '法国最大的城市是哪座?');
INSERT INTO `question_classify_total` VALUES (42, '请描述光合作用的过程。');
INSERT INTO `question_classify_total` VALUES (43, '什么是爱因斯坦的相对论?');
INSERT INTO `question_classify_total` VALUES (44, '气候变化对地球的影响是什么?');
INSERT INTO `question_classify_total` VALUES (45, '水净化系统是如何工作的?');
INSERT INTO `question_classify_total` VALUES (46, '为什么均衡饮食很重要?');
INSERT INTO `question_classify_total` VALUES (47, '疫苗如何帮助免疫系统?');
INSERT INTO `question_classify_total` VALUES (48, '社交媒体如何影响社会?');
INSERT INTO `question_classify_total` VALUES (49, '什么是机器学习,它与人工智能有何关系?');
INSERT INTO `question_classify_total` VALUES (50, '可再生能源与不可再生能源如何对比?');