项目需求:我们需要对面试经验信息进行总结,但是面试经验的数据量非常大,直接传给大模型会面临以下问题
- 大模型输入长度有限:大部分的大模型输入长度为2048token,远远低于我们的文本量,即便是一篇博客的内容,也有可能超过上限
- 面对长文本的总结信息受限:由于我们用的是6B模型,并不如目前常见的大模型如gpt4、通义千问等效果好,输入文本越长,总结能力越差。
根据上述问题,我们可以先对面试经验中的句子进行聚类,将相似的句子分类在一起,再让大模型对类似的信息进行概括,这样总结能力可能会好一点。
实现
使用 transformers
框架加载本地的 bge-large-zh
模型,对输入文本进行嵌入,再使用 sklearn
中的聚类算法对句子向量聚类。
Embedding
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
import torch
import torch.nn.functional as F
def embedding(sentences):
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 加载模型
tokenizer = AutoTokenizer.from_pretrained('../BAAI/bge-large-zh')
model = AutoModel.from_pretrained('../BAAI/bge-large-zh').to(device)
# tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
# model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to(device)
# Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(device)
# Compute embeddings
with torch.no_grad():
model_output = model(**encoded_input)
# 池化和归一化
sentence_embeddings = model_output[0][:, 0]
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1).to(device)
return sentence_embeddings
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0] # First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
X = embedding(questions)
结果
class 0 : 100w数据的hashmap扩容复制会有什么问题
class 0 : 回到100w,hashmap扩容复制迁移数据性能影响怎么解决
class 1 : sychorinized的作用,解决什么问题
class 1 : sychorinized作用在方法和对象上的区别
class 2 : 创新创业大赛的项目是什么,几个人写的
class 2 : 大学有做过哪些不是个人写的是协作写的项目
class 3 : hashmap的底层结构
class 3 : hashmap插入数据流程
class 3 : hashmap解决冲突方法
class 3 : hashmap数量特别多怎么解决插入和查找性能
class 3 : hashmap是怎么扩容的
class 3 : hashmap的桶值是怎么计算出的
class 3 : 手撕:双向链表插入+层序遍历二叉树
class 4 : CAS实现原理
class 4 : CAS会自旋多久
class 5 : Exception受检异常和非受检的区别
class 5 : 受检异常和非受检异常各在什么情况下使用
class 5 : 编程中如何用好两个异常
class 6 : ABA问题是什么
class 6 : 为什么会有ABA问题
class 7 : jvm的结构
class 8 : JVM内存泄漏是什么
class 8 : 什么情况会导致内存泄漏
class 9 : 版本号方案中版本号没有修改怎么办
class 9 : 函数重载为什么不用方法的返回值实现重载
class 9 : 怎么知道调用的重载的哪个方法
class 9 : 为什么要用redis不用mysql
class 10 : 二面
class 11 : 怎么学习一门新技术
class 11 : 怎么评价你自己是一个怎么样的人
class 12 : 还有没有用过其他并发同步技术
class 12 : 微博/微信的后台推送机制是怎么实现的
class 13 : 数据库有哪些索引
class 14 : 接口和抽象类的区别
class 14 : 还有没有其他区别
Process finished with exit code 0
聚类算法
k-means
k-means 算法的超参:簇的个数
from sklearn.cluster import KMeans, DBSCAN
from sklearn.cluster import AgglomerativeClustering
from transformers import AutoTokenizer, AutoModel
cluster_num = 15
# 初始化聚类模型
kmeans = KMeans(n_clusters=cluster_num)
# 进行聚类
kmeans.fit(X.cpu())
# 获取每个问题的类别
labels = kmeans.labels_
# 输出每个问题所属的类别
cluster_result = []
for i in range(cluster_num):
cluster_result.append([])
for i in range(len(questions2)):
# print("Question", i + 1, "belongs to cluster:", labels[i])
cluster_result[labels[i]].append(questions2[i])
for i in range(cluster_num):
for j in range(len(cluster_result[i])):
print("class " + str(i) + " : " + cluster_result[i][j])
dbscan
dbscan的超参是簇内实体的最大距离
dbscan = DBSCAN(eps=0.15, min_samples=2, metric='cosine')
dbscan.fit(X.cpu())
labels = dbscan.labels_
cluster_result = dict()
for i in range(len(questions2)):
if labels[i] not in cluster_result:
cluster_result[labels[i]] = []
cluster_result[labels[i]].append(i)
for label, sentences in cluster_result.items():
for i in sentences:
print("class " + str(label) + " : " + questions2[i])
结果
class 0 : sychorinized的作用,解决什么问题
class 0 : sychorinized作用在方法和对象上的区别
class -1 : 还有没有用过其他并发同步技术
class -1 : CAS实现原理
class -1 : CAS会自旋多久
class -1 : 版本号方案中版本号没有修改怎么办
class -1 : 编程中如何用好两个异常
class -1 : 接口和抽象类的区别
class -1 : 还有没有其他区别
class -1 : 手撕:双向链表插入+层序遍历二叉树
class -1 : 二面
class -1 : 创新创业大赛的项目是什么,几个人写的
class -1 : 大学有做过哪些不是个人写的是协作写的项目
class -1 : 微博/微信的后台推送机制是怎么实现的
class -1 : 为什么要用redis不用mysql
class -1 : 数据库有哪些索引
class -1 : 怎么学习一门新技术
class -1 : 怎么评价你自己是一个怎么样的人
class 1 : ABA问题是什么
class 1 : 为什么会有ABA问题
class 2 : Exception受检异常和非受检的区别
class 2 : 受检异常和非受检异常各在什么情况下使用
class 3 : JVM内存泄漏是什么
class 3 : 什么情况会导致内存泄漏
class 3 : jvm的结构
class 4 : 函数重载为什么不用方法的返回值实现重载
class 4 : 怎么知道调用的重载的哪个方法
class 5 : hashmap的底层结构
class 5 : hashmap插入数据流程
class 5 : hashmap解决冲突方法
class 5 : hashmap数量特别多怎么解决插入和查找性能
class 5 : 100w数据的hashmap扩容复制会有什么问题
class 5 : hashmap是怎么扩容的
class 5 : hashmap的桶值是怎么计算出的
class 5 : 回到100w,hashmap扩容复制迁移数据性能影响怎么解决
Process finished with exit code 0
层次聚类
def agglomerative(embeddings):
global questions2
clustering_model = AgglomerativeClustering(
n_clusters=None, distance_threshold=0.85
) # , affinity='cosine', linkage='average', distance_threshold=0.4)
clustering_model.fit(embeddings.cpu())
cluster_assignment = clustering_model.labels_
clustered_sentences = {}
for sentence_id, cluster_id in enumerate(cluster_assignment):
if cluster_id not in clustered_sentences:
clustered_sentences[cluster_id] = []
clustered_sentences[cluster_id].append(questions2[sentence_id])
for i, cluster in clustered_sentences.items():
print("Cluster ", i + 1)
print(cluster)
print("")
结果
Cluster 3
[‘sychorinized的作用,解决什么问题’, ‘sychorinized作用在方法和对象上的区别’, ‘接口和抽象类的区别’, ‘还有没有其他区别’, ‘二面’, ‘数据库有哪些索引’]
Cluster 6
[‘还有没有用过其他并发同步技术’, ‘版本号方案中版本号没有修改怎么办’, ‘函数重载为什么不用方法的返回值实现重载’, ‘怎么知道调用的重载的哪个方法’, ‘微博/微信的后台推送机制是怎么实现的’, ‘为什么要用redis不用mysql’]
Cluster 7
[‘CAS实现原理’, ‘CAS会自旋多久’]
Cluster 5
[‘ABA问题是什么’, ‘为什么会有ABA问题’]
Cluster 4
[‘Exception受检异常和非受检的区别’, ‘受检异常和非受检异常各在什么情况下使用’, ‘编程中如何用好两个异常’]
Cluster 8
[‘JVM内存泄漏是什么’, ‘什么情况会导致内存泄漏’, ‘jvm的结构’]
Cluster 2
[‘hashmap的底层结构’, ‘hashmap插入数据流程’, ‘hashmap解决冲突方法’, ‘hashmap数量特别多怎么解决插入和查找性能’, ‘100w数据的hashmap扩容复制会有什么问题’, ‘hashmap是怎么扩容的’, ‘hashmap的桶值是怎么计算出的’, ‘回到100w,hashmap扩容复制迁移数据性能影响怎么解决’, ‘手撕:双向链表插入+层序遍历二叉树’]
Cluster 1
[‘创新创业大赛的项目是什么,几个人写的’, ‘大学有做过哪些不是个人写的是协作写的项目’, ‘怎么学习一门新技术’, ‘怎么评价你自己是一个怎么样的人’]