Semantic retrieval: training the recall-stage semantic model with in-batch negatives (in_batch_negative)

import os
import random
import time
import numpy as np
import paddle
from functools import partial
from paddlenlp.utils.log import logger
from paddlenlp.data import Tuple, Pad
from paddlenlp.datasets import load_dataset, MapDataset
from paddlenlp.transformers import AutoModel, AutoTokenizer
from paddlenlp.transformers import LinearDecayWithWarmup
import paddle.nn.functional as F

import paddle.nn as nn
import abc
import hnswlib

This script is used after the unsupervised stage: it loads the model produced by unsupervised training and continues training it for semantic similarity with in-batch negatives.

class SemanticIndexBase(nn.Layer):
    def __init__(self, pretrained_model, dropout=None, output_emb_size=None):
        super().__init__()
        self.ptm = pretrained_model
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
        self.output_emb_size = output_emb_size
        # only add the dimension-reduction layer for a positive output_emb_size
        if output_emb_size is not None and output_emb_size > 0:
            weight_attr = paddle.ParamAttr(
                initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
            self.emb_reduce_linear = paddle.nn.Linear(768,
                                                      output_emb_size,
                                                      weight_attr=weight_attr)
    def get_pooled_embedding(self,
                             input_ids,
                             token_type_ids=None,
                             position_ids=None,
                             attention_mask=None):
        _, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,  # pooled [CLS] output
                                    attention_mask)
        if self.output_emb_size is not None and self.output_emb_size > 0:
            cls_embedding = self.emb_reduce_linear(cls_embedding)
        cls_embedding = self.dropout(cls_embedding)
        cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)  # L2-normalize to unit vectors, shape (n, d)
        return cls_embedding
    def get_semantic_embedding(self, data_loader):
        self.eval()
        with paddle.no_grad():
            for batch_data in data_loader:
                input_ids, token_type_ids = batch_data
                text_embeddings = self.get_pooled_embedding(
                    input_ids, token_type_ids=token_type_ids)
                yield text_embeddings  # yields the embeddings of each input batch, shape (n, d)
    def cosine_sim(self,
                   query_input_ids,
                   title_input_ids,
                   query_token_type_ids=None,
                   query_position_ids=None,
                   query_attention_mask=None,
                   title_token_type_ids=None,
                   title_position_ids=None,
                   title_attention_mask=None):
        query_cls_embedding = self.get_pooled_embedding(query_input_ids,
                                                        query_token_type_ids,
                                                        query_position_ids,
                                                        query_attention_mask)
        title_cls_embedding = self.get_pooled_embedding(title_input_ids,
                                                        title_token_type_ids,
                                                        title_position_ids,
                                                        title_attention_mask)
        cosine_sim = paddle.sum(query_cls_embedding * title_cls_embedding,
                                axis=-1)  # (n,) cosine similarity between each query and its title
        return cosine_sim
    @abc.abstractmethod  # abstract method; subclasses must override forward
    def forward(self):
        pass
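
A quick sanity sketch (not from the original post, toy numbers only): because get_pooled_embedding L2-normalizes its output, the element-wise product summed over the last axis in cosine_sim is exactly the cosine of the angle between the two vectors.

import paddle
import paddle.nn.functional as F

q = paddle.to_tensor([[3.0, 4.0]])           # toy query vector, shape (1, 2)
t = paddle.to_tensor([[4.0, 3.0]])           # toy title vector, shape (1, 2)
q_norm = F.normalize(q, p=2, axis=-1)        # unit vector [0.6, 0.8]
t_norm = F.normalize(t, p=2, axis=-1)        # unit vector [0.8, 0.6]
print(paddle.sum(q_norm * t_norm, axis=-1))  # 0.96 = (3*4 + 4*3) / (5*5)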

class SemanticIndexBatchNeg(SemanticIndexBase):
    def __init__(self,
                 pretrained_model,
                 dropout=None,
                 margin=0.3,
                 scale=30,
                 output_emb_size=None):
        super().__init__(pretrained_model, dropout, output_emb_size)
        self.margin = margin
        # scale factor used to sharpen the cosine similarities before the softmax
        self.scale = scale
    def forward(self,
                query_input_ids,
                title_input_ids,
                query_token_type_ids=None,
                query_position_ids=None,
                query_attention_mask=None,
                title_token_type_ids=None,
                title_position_ids=None,
                title_attention_mask=None):
        query_cls_embedding = self.get_pooled_embedding(query_input_ids,
                                                        query_token_type_ids,
                                                        query_position_ids,
                                                        query_attention_mask)
        title_cls_embedding = self.get_pooled_embedding(title_input_ids,
                                                        title_token_type_ids,
                                                        title_position_ids,
                                                        title_attention_mask)
        # Each row of the matrix holds the cosine similarities between one query and every title in the batch;
        # each column holds the similarities between one title and every query.
        # Diagonal entries pair each query with its own title (positives);
        # off-diagonal entries pair a query with other samples' titles (in-batch negatives).
        cosine_sim = paddle.matmul(query_cls_embedding,  # (n, d) @ (d, n) = (n, n)
                                   title_cls_embedding,
                                   transpose_y=True)
        # subtract the margin from all positive (diagonal) similarities
        margin_diag = paddle.full(shape=[query_cls_embedding.shape[0]],
                                  fill_value=self.margin,
                                  dtype=paddle.get_default_dtype())
        cosine_sim = cosine_sim - paddle.diag(margin_diag)  # only the diagonal entries are reduced; the rest stay unchanged
        cosine_sim *= self.scale  # sharpen the similarity distribution
        labels = paddle.arange(0, query_cls_embedding.shape[0], dtype='int64')  # ground truth: query i matches title i
        labels = paddle.reshape(labels, shape=[-1, 1])  # (n, 1)
        loss = F.cross_entropy(input=cosine_sim, label=labels)  # multi-class cross entropy over the batch
        return loss
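
A minimal numeric sketch of the loss above (the similarity values are made up, not produced by the model): the diagonal of the query-title similarity matrix holds the positive pairs, the margin is subtracted only there, the result is scaled, and cross-entropy with labels 0..n-1 pushes each row's largest entry towards its own diagonal position.

import paddle
import paddle.nn.functional as F

cosine_sim = paddle.to_tensor([[0.9, 0.2, 0.1],
                               [0.3, 0.8, 0.2],
                               [0.1, 0.4, 0.7]])  # (3, 3): rows are queries, columns are in-batch titles
margin, scale = 0.2, 30
cosine_sim = cosine_sim - paddle.diag(paddle.full([3], margin))  # penalize only the positives
cosine_sim = cosine_sim * scale                                  # sharpen the softmax distribution
labels = paddle.arange(0, 3, dtype='int64').reshape([-1, 1])     # the positive title of query i is title i
loss = F.cross_entropy(input=cosine_sim, label=labels)
print(loss.item())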

def read_text_pair(data_path):
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = line.rstrip().split("\t")
            if len(data) != 2:
                continue
            yield {'text_a': data[0], 'text_b': data[1]}

def convert_example(example,
                    tokenizer,
                    max_seq_length=512,
                    pad_to_max_seq_len=False):
    result = []
    for key, text in example.items():
        encoded_inputs = tokenizer(text=text,
                                   max_seq_len=max_seq_length,
                                   pad_to_max_seq_len=pad_to_max_seq_len)
        input_ids = encoded_inputs["input_ids"]
        token_type_ids = encoded_inputs["token_type_ids"]
        result += [input_ids, token_type_ids]
    return result
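
A small usage sketch (the two sentences are made up) showing the structure convert_example returns: four lists in the order query input_ids, query token_type_ids, title input_ids, title token_type_ids, which matches the four-field Tuple/Pad batchify_fn defined below.

from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('rocketqa-zh-base-query-encoder')
example = {'text_a': '如何办理社保', 'text_b': '社保办理流程'}
features = convert_example(example, tokenizer, max_seq_length=64)
query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids = features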

def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
    if trans_fn:
        dataset = dataset.map(trans_fn)
    shuffle = (mode == 'train')
    batch_sampler = paddle.io.BatchSampler(dataset,
                                           batch_size=batch_size,
                                           shuffle=shuffle)
    return paddle.io.DataLoader(dataset=dataset,
                                batch_sampler=batch_sampler,
                                collate_fn=batchify_fn,
                                return_list=True)

def gen_id2corpus(corpus_file):
    id2corpus = {}
    with open(corpus_file, 'r', encoding='utf-8') as f:
        for idx, line in enumerate(f):
            id2corpus[idx] = line.rstrip()
    return id2corpus

def gen_text_file(similar_text_pair_file):
    text2similar_text = {}
    texts = []
    with open(similar_text_pair_file, 'r', encoding='utf-8') as f:
        for line in f:
            splited_line = line.rstrip().split("\t")
            if len(splited_line) != 2:
                continue
            text, similar_text = splited_line
            if not text or not similar_text:
                continue
            text2similar_text[text] = similar_text
            texts.append({"text": text})
    return texts, text2similar_text

output_emb_size=256        # dimension of the reduced output embedding
hnsw_max_elements=1000000  # maximum number of vectors the HNSW index can hold
hnsw_ef=100                # size of the candidate list at build/search time
hnsw_m=100                 # number of bi-directional links per node

def build_index(data_loader, model):
    index = hnswlib.Index(
        space='ip',
        dim=output_emb_size if output_emb_size > 0 else 768)
    index.init_index(max_elements=hnsw_max_elements,
                     ef_construction=hnsw_ef,
                     M=hnsw_m)
    index.set_ef(hnsw_ef)
    index.set_num_threads(10)
    logger.info("start build index..........")
    all_embeddings = []
    for text_embeddings in model.get_semantic_embedding(data_loader):
        all_embeddings.append(text_embeddings.numpy())
    all_embeddings = np.concatenate(all_embeddings, axis=0)
    index.add_items(all_embeddings)
    logger.info("Total index number:{}".format(index.get_current_count()))
    return index
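
A self-contained sketch of how the index built above is queried (random unit vectors stand in for real corpus embeddings): with space='ip', hnswlib returns distances equal to 1 minus the inner product, which for L2-normalized embeddings is 1 minus the cosine similarity; that is why evaluate() below writes 1.0 - cosine_sims.

import numpy as np
import hnswlib

dim = 256
emb = np.random.rand(100, dim).astype('float32')
emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # L2-normalize, as get_pooled_embedding does
index = hnswlib.Index(space='ip', dim=dim)
index.init_index(max_elements=1000, ef_construction=100, M=16)
index.add_items(emb)
ids, distances = index.knn_query(emb[:2], k=5)     # query with the first two vectors
print(ids.shape, (1.0 - distances)[:, 0])          # top-1 cosine similarity is ~1.0 (the vector itself)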

save_dir='./checkpoint/yysy/'
max_seq_length=64
batch_size=6
model_name_or_path='rocketqa-zh-base-query-encoder'
learning_rate=5e-5
weight_decay=0.0
epochs=5
warmup_proportion=0.0
init_from_ckpt='./checkpoint/yysy/yysy_recall_best_2.pdparams'
seed=1000
device='gpu'
train_set_file='./datasets/yysy/recall/train.csv'
dev_set_file='./datasets/yysy/recall/dev_s.csv'
margin=0.2
scale=30
corpus_file='./datasets/yysy/recall/corpus_s.csv'
similar_text_pair_file='./datasets/yysy/recall/dev_s.csv'
recall_result_dir='./recall_result_dir/'
recall_result_file='batch_neg_recall_result.txt'
recall_num=20
evaluate_result='batch_neg_evaluate_result.txt'

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)

@paddle.no_grad()
def evaluate(model, corpus_data_loader, query_data_loader, recall_result_file,
             text_list, id2corpus):
    def recall(rs, N=10):
        recall_flags = [np.sum(r[0:N]) for r in rs]
        return np.mean(recall_flags)
    recall_num = 20  # hard-coded here (mirrors the global setting); rebound to a list of N values further below
    similar_text_pair_file = './datasets/yysy/recall/dev_s.csv'
    inner_model = model._layers  # unwrap paddle.DataParallel
    final_index = build_index(corpus_data_loader, inner_model)  # build the ANN index over the corpus
    query_embedding = inner_model.get_semantic_embedding(query_data_loader)  # generator of query-batch embeddings
    with open(recall_result_file, 'w', encoding='utf-8') as f:  # write the recall results to file
        for batch_index, batch_query_embedding in enumerate(query_embedding):
            # hnswlib with space='ip' returns distances = 1 - inner product, so
            # "cosine_sims" actually holds distances; 1 - distance recovers the cosine similarity
            recalled_idx, cosine_sims = final_index.knn_query(
                batch_query_embedding.numpy(), recall_num)
            batch_size = len(cosine_sims)
            for row_index in range(batch_size):
                text_index = batch_size * batch_index + row_index
                for idx, doc_idx in enumerate(recalled_idx[row_index]):
                    f.write("{}\t{}\t{}\n".format(
                        text_list[text_index]["text"], id2corpus[doc_idx],
                        1.0 - cosine_sims[row_index][idx]))
    text2similar = {}  # gold similar-text pairs from the dev set
    with open(similar_text_pair_file, 'r', encoding='utf-8') as f:
        for line in f:
            text, similar_text = line.rstrip().split("\t")
            text2similar[text] = similar_text
    rs = []
    with open(recall_result_file, 'r', encoding='utf-8') as f:
        relevance_labels = []
        for index, line in enumerate(f):
            if index % recall_num == 0 and index != 0:  # every recall_num lines belong to one query; flush its relevance marks
                rs.append(relevance_labels)
                relevance_labels = []
            text, recalled_text, cosine_sim = line.rstrip().split("\t")
            if text == recalled_text:
                continue
            if text2similar[text] == recalled_text:  # the gold similar text was recalled: mark 1
                relevance_labels.append(1)
            else:
                relevance_labels.append(0)  # otherwise mark 0 (not found at this position)
        rs.append(relevance_labels)  # flush the last query's marks, which the modulo check above never reaches
    recall_N = []
    recall_num = [1, 5, 10, 20, 50]  # N values to report (this rebinds the local recall_num to a list)
    for topN in recall_num:
        R = round(100 * recall(rs, N=topN), 3)
        recall_N.append(str(R))
    # evaluate_result_file = os.path.join(recall_result_dir,
    #                                     evaluate_result)
    # result = open(evaluate_result_file, 'a')
    res = []
    timestamp = time.strftime('%Y%m%d-%H%M%S', time.localtime())
    res.append(timestamp)
    for key, val in zip(recall_num, recall_N):
        print('recall@{}={}'.format(key, val))
        res.append(str(val))
    # result.write('\t'.join(res) + '\n')
    print(recall_N)
    return float(recall_N[1])  # recall@5, used for model selection during training
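
A tiny worked example of the recall@N metric computed above (the relevance lists are made up): each inner list marks, per recalled position, whether that query's gold similar text was hit, and recall@N is the fraction of queries whose gold text appears within the top N.

import numpy as np

rs = [[0, 1, 0, 0, 0],   # query 1: gold text recalled at rank 2
      [1, 0, 0, 0, 0],   # query 2: gold text recalled at rank 1
      [0, 0, 0, 0, 0]]   # query 3: gold text missed in the top 5

def recall(rs, N=10):
    return np.mean([np.sum(r[0:N]) for r in rs])

print(recall(rs, N=1))   # 0.333...
print(recall(rs, N=5))   # 0.666...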

paddle.set_device(device)

set_seed(seed)

train_ds = load_dataset(read_text_pair,
                        data_path=train_set_file,
                        lazy=False)

pretrained_model = AutoModel.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

trans_func = partial(convert_example,
                     tokenizer=tokenizer,
                     max_seq_length=max_seq_length)
batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'
        ),  # query_input
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'
        ),  # query_segment
    Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'
        ),  # title_input
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'
        ),  # title_segment
): [data for data in fn(samples)]
train_data_loader = create_dataloader(train_ds,
                                      mode='train',
                                      batch_size=batch_size,
                                      batchify_fn=batchify_fn,
                                      trans_fn=trans_func)

for i in train_data_loader:  # sanity check: print one collated batch
    print(i)
    break

model = SemanticIndexBatchNeg(pretrained_model,
                              margin=margin,
                              scale=scale,
                              output_emb_size=output_emb_size)

if init_from_ckpt and os.path.isfile(init_from_ckpt):
    state_dict = paddle.load(init_from_ckpt)
    model.set_dict(state_dict)
    print("warmup from:{}".format(init_from_ckpt))

model = paddle.DataParallel(model)

batchify_fn_dev = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'
            ),  # text_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'
            ),  # text_segment
    ): [data for data in fn(samples)]

id2corpus = gen_id2corpus(corpus_file)

corpus_list = [{idx: text} for idx, text in id2corpus.items()]
corpus_ds = MapDataset(corpus_list)

corpus_data_loader = create_dataloader(corpus_ds,  # corpus used to build the ANN index
                                       mode='predict',
                                       batch_size=batch_size,
                                       batchify_fn=batchify_fn_dev,
                                       trans_fn=trans_func)

text_list, text2similar_text = gen_text_file(similar_text_pair_file)
query_ds = MapDataset(text_list)
query_data_loader = create_dataloader(query_ds,  # dev-set queries
                                      mode='predict',
                                      batch_size=batch_size,
                                      batchify_fn=batchify_fn_dev,
                                      trans_fn=trans_func)

if not os.path.exists(recall_result_dir):  # directory for the recall result files
    os.mkdir(recall_result_dir)
recall_result_file = os.path.join(recall_result_dir,
                                  recall_result_file)

num_training_steps = len(train_data_loader) * epochs

lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps,
                                     warmup_proportion)
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in decay_params)

def train(model, dataloader, optimizer, scheduler, global_step):
    model.train()
    total_loss = 0.
    for step, batch in enumerate(dataloader, start=1):
        query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids = batch
        loss = model(
            query_input_ids=query_input_ids,
            title_input_ids=title_input_ids,
            query_token_type_ids=query_token_type_ids,
            title_token_type_ids=title_token_type_ids)
        loss.backward()          # backpropagation
        optimizer.step()         # parameter update
        scheduler.step()         # learning-rate update
        optimizer.clear_grad()   # reset gradients to zero
        global_step += 1
        total_loss += loss.item()
        if global_step % 30 == 0:
            print("global_step %d - batch: %d - loss: %.5f"
                  % (global_step, step, loss.item()))
    avg_loss = total_loss / len(dataloader)
    return avg_loss, global_step

def train_epochs(epochs, save_path):
    global_step = 0  # global step counter
    best_recall = 0.0
    for epoch in range(1, epochs + 1):
        avg_loss, global_step = train(model, train_data_loader, optimizer,
                                      lr_scheduler, global_step)
        print("epoch:%d - global_step:%d - avg_loss: %.4f - best_recall: %.4f - lr:%.8f"
              % (epoch, global_step, avg_loss, best_recall, optimizer.get_lr()))
        recall_5 = evaluate(model, corpus_data_loader, query_data_loader,
                            recall_result_file, text_list, id2corpus)
        if recall_5 > best_recall:  # keep the checkpoint with the best recall@5
            paddle.save(model.state_dict(),
                        save_path + 'yysy_in_batch_negative_recall_1.pdparams')
            tokenizer.save_pretrained(save_path)
            best_recall = recall_5

if not os.path.exists(save_dir): os.makedirs(save_dir)
train_epochs(epochs,save_dir)
