【项目实训】基于bge-large的自然语言文本聚类

项目需求:我们需要对面试经验信息进行总结,但是面试经验的数据量非常大,直接传给大模型会面临以下问题

  • 大模型输入长度有限:大部分的大模型输入长度为2048token,远远低于我们的文本量,即便是一篇博客的内容,也有可能超过上限
  • 面对长文本的总结信息受限:由于我们用的是6B模型,并不如目前常见的大模型如gpt4、通义千问等效果好,输入文本越长,总结能力越差。

根据上述问题,我们可以先对面试经验中的句子进行聚类,将相似的句子分类在一起,再让大模型对类似的信息进行概括,这样总结能力可能会好一点。

实现

使用 transformers 框架加载本地的 bge-large-zh模型,对输入文本进行嵌入,再使用 sklearn 中的聚类算法对句子向量聚类。

Embedding

from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
import torch
import torch.nn.functional as F


def embedding(sentences):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # 加载模型
    tokenizer = AutoTokenizer.from_pretrained('../BAAI/bge-large-zh')
    model = AutoModel.from_pretrained('../BAAI/bge-large-zh').to(device)
    # tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
    # model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to(device)

    # Tokenize sentences
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(device)

    # Compute embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # 池化和归一化
    sentence_embeddings = model_output[0][:, 0]
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1).to(device)
    return sentence_embeddings

def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

X = embedding(questions)

结果

class 0 : 100w数据的hashmap扩容复制会有什么问题
class 0 : 回到100w,hashmap扩容复制迁移数据性能影响怎么解决
class 1 : sychorinized的作用,解决什么问题
class 1 : sychorinized作用在方法和对象上的区别
class 2 : 创新创业大赛的项目是什么,几个人写的
class 2 : 大学有做过哪些不是个人写的是协作写的项目
class 3 : hashmap的底层结构
class 3 : hashmap插入数据流程
class 3 : hashmap解决冲突方法
class 3 : hashmap数量特别多怎么解决插入和查找性能
class 3 : hashmap是怎么扩容的
class 3 : hashmap的桶值是怎么计算出的
class 3 : 手撕:双向链表插入+层序遍历二叉树
class 4 : CAS实现原理
class 4 : CAS会自旋多久
class 5 : Exception受检异常和非受检的区别
class 5 : 受检异常和非受检异常各在什么情况下使用
class 5 : 编程中如何用好两个异常
class 6 : ABA问题是什么
class 6 : 为什么会有ABA问题
class 7 : jvm的结构
class 8 : JVM内存泄漏是什么
class 8 : 什么情况会导致内存泄漏
class 9 : 版本号方案中版本号没有修改怎么办
class 9 : 函数重载为什么不用方法的返回值实现重载
class 9 : 怎么知道调用的重载的哪个方法
class 9 : 为什么要用redis不用mysql
class 10 : 二面
class 11 : 怎么学习一门新技术
class 11 : 怎么评价你自己是一个怎么样的人
class 12 : 还有没有用过其他并发同步技术
class 12 : 微博/微信的后台推送机制是怎么实现的
class 13 : 数据库有哪些索引
class 14 : 接口和抽象类的区别
class 14 : 还有没有其他区别

Process finished with exit code 0

聚类算法

k-means

k-means 算法的超参:簇的个数

from sklearn.cluster import KMeans, DBSCAN
from sklearn.cluster import AgglomerativeClustering
from transformers import AutoTokenizer, AutoModel

cluster_num = 15
# 初始化聚类模型
kmeans = KMeans(n_clusters=cluster_num)
# 进行聚类
kmeans.fit(X.cpu())

# 获取每个问题的类别
labels = kmeans.labels_

# 输出每个问题所属的类别
cluster_result = []
for i in range(cluster_num):
    cluster_result.append([])
for i in range(len(questions2)):
    # print("Question", i + 1, "belongs to cluster:", labels[i])
    cluster_result[labels[i]].append(questions2[i])
for i in range(cluster_num):
    for j in range(len(cluster_result[i])):
        print("class " + str(i) + " : " + cluster_result[i][j])

dbscan

dbscan的超参是簇内实体的最大距离

dbscan = DBSCAN(eps=0.15, min_samples=2, metric='cosine')
dbscan.fit(X.cpu())
labels = dbscan.labels_
cluster_result = dict()
    for i in range(len(questions2)):
        if labels[i] not in cluster_result:
            cluster_result[labels[i]] = []
        cluster_result[labels[i]].append(i)
    for label, sentences in cluster_result.items():
        for i in sentences:
            print("class " + str(label) + " : " + questions2[i])

结果

class 0 : sychorinized的作用,解决什么问题
class 0 : sychorinized作用在方法和对象上的区别
class -1 : 还有没有用过其他并发同步技术
class -1 : CAS实现原理
class -1 : CAS会自旋多久
class -1 : 版本号方案中版本号没有修改怎么办
class -1 : 编程中如何用好两个异常
class -1 : 接口和抽象类的区别
class -1 : 还有没有其他区别
class -1 : 手撕:双向链表插入+层序遍历二叉树
class -1 : 二面
class -1 : 创新创业大赛的项目是什么,几个人写的
class -1 : 大学有做过哪些不是个人写的是协作写的项目
class -1 : 微博/微信的后台推送机制是怎么实现的
class -1 : 为什么要用redis不用mysql
class -1 : 数据库有哪些索引
class -1 : 怎么学习一门新技术
class -1 : 怎么评价你自己是一个怎么样的人
class 1 : ABA问题是什么
class 1 : 为什么会有ABA问题
class 2 : Exception受检异常和非受检的区别
class 2 : 受检异常和非受检异常各在什么情况下使用
class 3 : JVM内存泄漏是什么
class 3 : 什么情况会导致内存泄漏
class 3 : jvm的结构
class 4 : 函数重载为什么不用方法的返回值实现重载
class 4 : 怎么知道调用的重载的哪个方法
class 5 : hashmap的底层结构
class 5 : hashmap插入数据流程
class 5 : hashmap解决冲突方法
class 5 : hashmap数量特别多怎么解决插入和查找性能
class 5 : 100w数据的hashmap扩容复制会有什么问题
class 5 : hashmap是怎么扩容的
class 5 : hashmap的桶值是怎么计算出的
class 5 : 回到100w,hashmap扩容复制迁移数据性能影响怎么解决

Process finished with exit code 0

层次聚类

def agglomerative(embeddings):
    global questions2
    clustering_model = AgglomerativeClustering(
        n_clusters=None, distance_threshold=0.85
    )  # , affinity='cosine', linkage='average', distance_threshold=0.4)
    clustering_model.fit(embeddings.cpu())
    cluster_assignment = clustering_model.labels_

    clustered_sentences = {}
    for sentence_id, cluster_id in enumerate(cluster_assignment):
        if cluster_id not in clustered_sentences:
            clustered_sentences[cluster_id] = []

        clustered_sentences[cluster_id].append(questions2[sentence_id])

    for i, cluster in clustered_sentences.items():
        print("Cluster ", i + 1)
        print(cluster)
        print("")

结果

Cluster 3
[‘sychorinized的作用,解决什么问题’, ‘sychorinized作用在方法和对象上的区别’, ‘接口和抽象类的区别’, ‘还有没有其他区别’, ‘二面’, ‘数据库有哪些索引’]
Cluster 6
[‘还有没有用过其他并发同步技术’, ‘版本号方案中版本号没有修改怎么办’, ‘函数重载为什么不用方法的返回值实现重载’, ‘怎么知道调用的重载的哪个方法’, ‘微博/微信的后台推送机制是怎么实现的’, ‘为什么要用redis不用mysql’]
Cluster 7
[‘CAS实现原理’, ‘CAS会自旋多久’]
Cluster 5
[‘ABA问题是什么’, ‘为什么会有ABA问题’]
Cluster 4
[‘Exception受检异常和非受检的区别’, ‘受检异常和非受检异常各在什么情况下使用’, ‘编程中如何用好两个异常’]
Cluster 8
[‘JVM内存泄漏是什么’, ‘什么情况会导致内存泄漏’, ‘jvm的结构’]
Cluster 2
[‘hashmap的底层结构’, ‘hashmap插入数据流程’, ‘hashmap解决冲突方法’, ‘hashmap数量特别多怎么解决插入和查找性能’, ‘100w数据的hashmap扩容复制会有什么问题’, ‘hashmap是怎么扩容的’, ‘hashmap的桶值是怎么计算出的’, ‘回到100w,hashmap扩容复制迁移数据性能影响怎么解决’, ‘手撕:双向链表插入+层序遍历二叉树’]
Cluster 1
[‘创新创业大赛的项目是什么,几个人写的’, ‘大学有做过哪些不是个人写的是协作写的项目’, ‘怎么学习一门新技术’, ‘怎么评价你自己是一个怎么样的人’]

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值