Insurance QA System: Model Training and Recall Evaluation

from functools import partial
import os
import sys
import random
import time
from scipy import stats
import numpy as np
import paddle
import paddle.nn.functional as F
import paddle.nn as nn
import paddlenlp
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.datasets import load_dataset, MapDataset
from paddlenlp.transformers import LinearDecayWithWarmup
from paddlenlp.transformers import AutoTokenizer, AutoModel

import hnswlib
from paddlenlp.utils.log import logger

class SimCSE(nn.Layer):
    def __init__(self,
                 pretrained_model,
                 dropout=None,
                 margin=0.0,
                 scale=20,
                 output_emb_size=None):
        super().__init__()
        self.ptm = pretrained_model
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
        self.output_emb_size = output_emb_size
        if output_emb_size > 0:
            weight_attr = paddle.ParamAttr(
                initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
            self.emb_reduce_linear = paddle.nn.Linear(768,
                                                      output_emb_size,
                                                      weight_attr=weight_attr)
        self.margin = margin
        self.scale = scale
        self.classifier = nn.Linear(output_emb_size, 2)
        self.rdrop_loss = paddlenlp.losses.RDropLoss()
    @paddle.jit.to_static(input_spec=[
        paddle.static.InputSpec(shape=[None, None], dtype='int64'),
        paddle.static.InputSpec(shape=[None, None], dtype='int64')
    ])
    def get_pooled_embedding(self,
                             input_ids,
                             token_type_ids=None,
                             position_ids=None,
                             attention_mask=None,
                             with_pooler=True):
        sequence_output, cls_embedding = self.ptm(input_ids, token_type_ids,
                                                  position_ids, attention_mask)
        if not with_pooler:
            cls_embedding = sequence_output[:, 0, :]  # use the [CLS] token output instead of the pooler, shape (n, d)
        if self.output_emb_size > 0:
            cls_embedding = self.emb_reduce_linear(cls_embedding)
        cls_embedding = self.dropout(cls_embedding)
        cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)  # L2-normalize to unit vectors, shape (n, d)
        return cls_embedding
    def get_semantic_embedding(self, data_loader):
        self.eval()
        with paddle.no_grad():
            for batch_data in data_loader:
                input_ids, token_type_ids = batch_data
                text_embeddings = self.get_pooled_embedding(
                    input_ids, token_type_ids=token_type_ids)
                yield text_embeddings  # embedding of each text in the batch
    def cosine_sim(self,
                   query_input_ids,
                   title_input_ids,
                   query_token_type_ids=None,
                   query_position_ids=None,
                   query_attention_mask=None,
                   title_token_type_ids=None,
                   title_position_ids=None,
                   title_attention_mask=None,
                   with_pooler=True):
        query_cls_embedding = self.get_pooled_embedding(query_input_ids,
                                                        query_token_type_ids,
                                                        query_position_ids,
                                                        query_attention_mask,
                                                        with_pooler=with_pooler)
        title_cls_embedding = self.get_pooled_embedding(title_input_ids,
                                                        title_token_type_ids,
                                                        title_position_ids,
                                                        title_attention_mask,
                                                        with_pooler=with_pooler)
        cosine_sim = paddle.sum(query_cls_embedding * title_cls_embedding,
                                axis=-1)  # dot product of unit vectors, i.e. the cosine of the angle between them
        return cosine_sim
    def forward(self,
                query_input_ids,
                title_input_ids,
                query_token_type_ids=None,
                query_position_ids=None,
                query_attention_mask=None,
                title_token_type_ids=None,
                title_position_ids=None,
                title_attention_mask=None):
        query_cls_embedding = self.get_pooled_embedding(query_input_ids,
                                                        query_token_type_ids,
                                                        query_position_ids,
                                                        query_attention_mask)
        title_cls_embedding = self.get_pooled_embedding(title_input_ids,
                                                        title_token_type_ids,
                                                        title_position_ids,
                                                        title_attention_mask)
        logits1 = self.classifier(query_cls_embedding)  # (n, 2)
        logits2 = self.classifier(title_cls_embedding)
        kl_loss = self.rdrop_loss(logits1, logits2)  # query and title are semantically similar, so the R-Drop loss pushes their two distributions together
        cosine_sim = paddle.matmul(query_cls_embedding,  # (n, d) @ (d, n) = (n, n)
                                   title_cls_embedding,
                                   transpose_y=True)
        margin_diag = paddle.full(shape=[query_cls_embedding.shape[0]],
                                  fill_value=self.margin,
                                  dtype=paddle.get_default_dtype())
        cosine_sim = cosine_sim - paddle.diag(margin_diag)  # subtract the margin from the diagonal (the positive pairs)
        cosine_sim *= self.scale
        labels = paddle.arange(0, query_cls_embedding.shape[0], dtype='int64')
        labels = paddle.reshape(labels, shape=[-1, 1])  # (n, 1)
        loss = F.cross_entropy(input=cosine_sim, label=labels)  # in-batch negatives: each row's positive is its own diagonal entry
        return loss, kl_loss

labels = paddle.arange(0, 10, dtype='int64')
labels = paddle.reshape(labels, shape=[-1, 1])  # (n, 1)
labels
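
To make the training objective concrete, here is a small sketch using the paddle / F imports from above (the similarity values are made up, not real model outputs): the (n, n) cosine-similarity matrix is treated as n classification problems whose correct class index is the diagonal position, i.e. each query's own paired title.

# toy in-batch similarity matrix for n = 3, scaled like self.scale
toy_sim = paddle.to_tensor([[0.9, 0.1, 0.2],
                            [0.0, 0.8, 0.3],
                            [0.1, 0.2, 0.7]]) * 20
toy_labels = paddle.arange(0, 3, dtype='int64').reshape([-1, 1])
print(F.cross_entropy(input=toy_sim, label=toy_labels))  # small loss because the diagonal dominates each row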

def read_simcse_text(data_path):
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = line.rstrip()
            yield {'text_a': data, 'text_b': data}

def read_text_pair(data_path, is_test=False):
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = line.rstrip().split("\t")
            if is_test:
                if len(data) != 3:
                    continue
                yield {'text_a': data[0], 'text_b': data[1], 'label': data[2]}
            else:
                if len(data) != 2:
                    continue
                yield {'text_a': data[0], 'text_b': data[1]}
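
The training file is expected to hold one tab-separated text pair per line (the test file carries an extra label column). A purely hypothetical line for illustration, since the contents of train_aug.csv are not shown here:

sample_line = "重疾险包含哪些疾病\t重大疾病保险覆盖哪些病种"  # made-up pair, not from the real dataset
print(sample_line.split("\t"))  # the two fields become text_a and text_b in read_text_pair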

def convert_example(example, tokenizer, max_seq_length=512, do_evaluate=False):
    result = []
    for key, text in example.items():
        if 'label' in key:
            # do_evaluate
            result += [example['label']]
        else:
            # do_train
            encoded_inputs = tokenizer(text=text, max_seq_len=max_seq_length)
            input_ids = encoded_inputs["input_ids"]
            token_type_ids = encoded_inputs["token_type_ids"]
            result += [input_ids, token_type_ids]
    return result

def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
    if trans_fn:
        dataset = dataset.map(trans_fn)
    shuffle = True if mode == 'train' else False
    if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(dataset,
                                                          batch_size=batch_size,
                                                          shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(dataset,
                                               batch_size=batch_size,
                                               shuffle=shuffle)
    return paddle.io.DataLoader(dataset=dataset,
                                batch_sampler=batch_sampler,
                                collate_fn=batchify_fn,
                                return_list=True)

def word_repetition(input_ids, token_type_ids, dup_rate=0.32):
    input_ids = input_ids.numpy().tolist()
    token_type_ids = token_type_ids.numpy().tolist()
    batch_size, seq_len = len(input_ids), len(input_ids[0])
    repetitied_input_ids = []
    repetitied_token_type_ids = []
    rep_seq_len = seq_len
    for batch_id in range(batch_size):
        cur_input_id = input_ids[batch_id]
        actual_len = np.count_nonzero(cur_input_id)
        dup_word_index = []
        # Only duplicate tokens when the sequence is long enough (actual_len > 5)
        if actual_len > 5:
            dup_len = random.randint(a=0, b=max(2, int(dup_rate * actual_len)))
            # Skip cls and sep position
            dup_word_index = random.sample(list(range(1, actual_len - 1)),
                                           k=dup_len)
        r_input_id = []
        r_token_type_id = []
        for idx, word_id in enumerate(cur_input_id):
            # Insert the duplicate: tokens at the sampled positions are appended twice
            if idx in dup_word_index:
                r_input_id.append(word_id)
                r_token_type_id.append(token_type_ids[batch_id][idx])
            r_input_id.append(word_id)
            r_token_type_id.append(token_type_ids[batch_id][idx])
        after_dup_len = len(r_input_id)
        repetitied_input_ids.append(r_input_id)
        repetitied_token_type_ids.append(r_token_type_id)
        if after_dup_len > rep_seq_len:
            rep_seq_len = after_dup_len
    # Padding the data to the same length
    for batch_id in range(batch_size):
        after_dup_len = len(repetitied_input_ids[batch_id])
        pad_len = rep_seq_len - after_dup_len
        repetitied_input_ids[batch_id] += [0] * pad_len
        repetitied_token_type_ids[batch_id] += [0] * pad_len
    return paddle.to_tensor(repetitied_input_ids,
                            dtype='int64'), paddle.to_tensor(
                                repetitied_token_type_ids, dtype='int64')
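
A quick, hypothetical illustration of word_repetition (the token ids below are made up; 101/102 stand in for [CLS]/[SEP] and 0 is padding):

toy_ids = paddle.to_tensor([[101, 11, 12, 13, 14, 15, 102, 0]], dtype='int64')
toy_segs = paddle.zeros_like(toy_ids)
dup_ids, dup_segs = word_repetition(toy_ids, toy_segs, dup_rate=0.32)
print(dup_ids)  # some middle tokens may appear twice; the batch is re-padded to the new max length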

save_dir='./checkpoint/bxqa/'
max_seq_length=64
batch_size=16
output_emb_size=256
learning_rate=5e-5
weight_decay=0.0
epochs=3
warmup_proportion=0.0
init_from_ckpt=None
seed=1000
device='gpu'
train_set_file='./datasets/bxqa/train_aug.csv'
margin=0.0
scale=20
dropout=0.2
dup_rate=0.1
infer_with_fc_pooler=True
rdrop_coef=0.1

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)

@paddle.no_grad()
def do_evaluate(model, tokenizer, data_loader, with_pooler=False):
    model.eval()
    total_num = 0
    spearman_corr = 0.0
    sims = []
    labels = []
    for batch in data_loader:
        query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids, label = batch
        total_num += len(label)
        query_cls_embedding = model.get_pooled_embedding(
            query_input_ids, query_token_type_ids, with_pooler=with_pooler)
        title_cls_embedding = model.get_pooled_embedding(title_input_ids, title_token_type_ids, with_pooler=with_pooler)
        cosine_sim = paddle.sum(query_cls_embedding * title_cls_embedding, axis=-1)
        sims.append(cosine_sim.numpy())
        labels.append(label.numpy())
    sims = np.concatenate(sims, axis=0)
    labels = np.concatenate(labels, axis=0)
    spearman_corr = stats.spearmanr(labels, sims).correlation
    return spearman_corr, total_num
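
As a sanity check on the metric, Spearman correlation only compares the ranking of predicted similarities against the gold labels (toy numbers below, not real evaluation output):

toy_gold = np.array([0, 1, 1, 0, 1])
toy_sims = np.array([0.12, 0.83, 0.65, 0.20, 0.91])
print(stats.spearmanr(toy_gold, toy_sims).correlation)  # strongly positive when the rankings agree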

paddle.set_device(device)

set_seed(seed)

train_ds = load_dataset(
    read_text_pair, data_path=train_set_file,is_test=False, lazy=False)
model_name_or_path='rocketqa-zh-dureader-query-encoder'
pretrained_model = AutoModel.from_pretrained(
   model_name_or_path,
   hidden_dropout_prob=dropout,
   attention_probs_dropout_prob=dropout)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

trans_func = partial(
    convert_example,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length)
batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # query_input
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"),  # query_segment
    Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # title_input
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"),  # tilte_segment
): [data for data in fn(samples)]
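
A small sketch of what this collate function does (the token ids are made up): Tuple splits each sample into its four fields, and each Pad pads one field to the longest length in the batch.

toy_samples = [[[1, 2, 3], [0, 0, 0], [1, 2], [0, 0]],
               [[1, 2], [0, 0], [1, 2, 3, 4], [0, 0, 0, 0]]]
q_ids, q_segs, t_ids, t_segs = batchify_fn(toy_samples)
print(q_ids.shape, t_ids.shape)  # (2, 3) and (2, 4) after padding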

train_data_loader = create_dataloader(
    train_ds,
    mode='train',
    batch_size=batch_size,
    batchify_fn=batchify_fn,
    trans_fn=trans_func)

for i in train_data_loader:
    print(i)
    break


model = SimCSE(
    pretrained_model,
    margin=margin,
    scale=scale,
    output_emb_size=output_emb_size)

if init_from_ckpt and os.path.isfile(init_from_ckpt):
    state_dict = paddle.load(init_from_ckpt)
    model.set_dict(state_dict)
    print("warmup from: {}".format(init_from_ckpt))

num_training_steps = len(train_data_loader) * epochs
lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps,
                                     warmup_proportion)
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in decay_params)  # apply weight decay to all parameters except bias and norm
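
A purely diagnostic check of the filter above (it does not affect training): parameters whose names contain "bias" or "norm" are excluded from weight decay.

excluded = [n for n, p in model.named_parameters()
            if any(nd in n for nd in ["bias", "norm"])]
print(len(decay_params), len(excluded))  # decayed vs. excluded parameter counts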

def train(model, dataloader, optimizer, scheduler, global_step):
    model.train()
    total_loss = 0.
    for step, batch in enumerate(dataloader, start=1):
        # Normally query and title are a semantically similar pair
        query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids = batch
        if random.random() < 0.2:
            # With probability 0.2, train with word repetition: the title is replaced by the
            # query itself, and both copies get independent token duplication
            title_input_ids, title_token_type_ids = query_input_ids, query_token_type_ids
            query_input_ids, query_token_type_ids = word_repetition(
                query_input_ids, query_token_type_ids, dup_rate)
            title_input_ids, title_token_type_ids = word_repetition(
                title_input_ids, title_token_type_ids, dup_rate)
        loss, kl_loss = model(
            query_input_ids=query_input_ids,
            title_input_ids=title_input_ids,
            query_token_type_ids=query_token_type_ids,
            title_token_type_ids=title_token_type_ids)
        loss = loss + kl_loss * rdrop_coef  # rdrop_coef: 0.1
        loss.backward()          # backpropagation
        optimizer.step()         # update parameters
        optimizer.clear_grad()   # reset gradients
        scheduler.step()         # update the learning rate
        global_step += 1
        total_loss += loss.item()
        if global_step % 30 == 0:
            print("global_step %d - batch: %d - loss: %.5f - dup_rate %.4f"
                  % (global_step, step, loss.item(), dup_rate))
    avg_loss = total_loss / len(dataloader)
    return avg_loss, global_step

def train_epochs(epochs, save_path):
    global_step = 0             # global step counter
    best_loss = float('inf')    # best average loss so far
    for epoch in range(1, epochs + 1):
        avg_loss, global_step = train(model, train_data_loader, optimizer,
                                      lr_scheduler, global_step)
        print("epoch:%d - global_step:%d - avg_loss: %.4f - best_loss:%.4f - lr:%.8f"
              % (epoch, global_step, avg_loss, best_loss, optimizer.get_lr()))
        if avg_loss < best_loss:
            paddle.save(model.state_dict(),
                        os.path.join(save_dir, 'bqqa_best_%d.pdparams' % global_step))
            tokenizer.save_pretrained(save_path)
            best_loss = avg_loss
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
train_epochs(epochs,save_dir)

def convert_example_test(example,
                         tokenizer,
                         max_seq_length=512,
                         pad_to_max_seq_len=False):
    result = []
    for key, text in example.items():
        encoded_inputs = tokenizer(text=text,
                                   max_seq_len=max_seq_length,
                                   pad_to_max_seq_len=pad_to_max_seq_len)
        input_ids = encoded_inputs["input_ids"]
        token_type_ids = encoded_inputs["token_type_ids"]
        result += [input_ids, token_type_ids]
    return result

def gen_id2corpus(corpus_file):
    id2corpus = {}
    with open(corpus_file, 'r', encoding='utf-8') as f:
        for idx, line in enumerate(f):
            id2corpus[idx] = line.rstrip()
    return id2corpus

def gen_text_file(similar_text_pair_file):
    text2similar_text = {}
    texts = []
    with open(similar_text_pair_file, 'r', encoding='utf-8') as f:
        for line in f:
            split_line = line.rstrip().split("\t")
            if len(split_line) != 2:
                continue
            text, similar_text = split_line
            if not text or not similar_text:
                continue
            text2similar_text[text] = similar_text
            texts.append({"text": text})
    return texts, text2similar_text

hnsw_max_elements=1000000
hnsw_ef=100
hnsw_m=100

def build_index( data_loader, model):
    index = hnswlib.Index(
        space='ip',
        dim=output_emb_size if output_emb_size > 0 else 768)
    index.init_index(max_elements=hnsw_max_elements,
                     ef_construction=hnsw_ef,
                     M=hnsw_m)
    index.set_ef(hnsw_ef)
    index.set_num_threads(10)
    logger.info("start build index..........")
    all_embeddings = []
    for text_embeddings in model.get_semantic_embedding(data_loader):
        all_embeddings.append(text_embeddings.numpy())
    all_embeddings = np.concatenate(all_embeddings, axis=0)  # concatenate all batch embeddings
    index.add_items(all_embeddings)
    logger.info("Total index number:{}".format(index.get_current_count()))
    return index
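
A minimal standalone hnswlib sketch with random unit vectors (not the real corpus): with space='ip' and L2-normalized embeddings, the distance hnswlib returns is 1 minus the inner product, i.e. 1 - cosine similarity, which is why the recall writer below converts it back with 1.0 - cosine_sims.

dim = 8
vecs = np.random.rand(100, dim).astype('float32')
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)
toy_index = hnswlib.Index(space='ip', dim=dim)
toy_index.init_index(max_elements=1000, ef_construction=100, M=16)
toy_index.add_items(vecs)
ids, dists = toy_index.knn_query(vecs[:2], k=3)
print(ids.shape, dists.shape)  # (2, 3) neighbour ids and their distances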

corpus_file='./datasets/bxqa/corpus.csv'
similar_text_pair_file='./datasets/bxqa/test_pair.csv'
recall_result_dir='recall_result_dir'
recall_result_file='bxqa_recall_result_1.txt'
params_path='./checkpoint/bxqa/bqqa_best_570.pdparams'
max_seq_length=64
batch_size=64
output_emb_size=256
recall_num=10

trans_func = partial(convert_example_test,
                     tokenizer=tokenizer,
                     max_seq_length=max_seq_length)

batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),  # text_input
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'),  # text_segment
): [data for data in fn(samples)]

if params_path and os.path.isfile(params_path):
    state_dict = paddle.load(params_path)
    model.set_dict(state_dict)
    logger.info("Loaded parameters from %s" % params_path)
else:
    raise ValueError(
        "Please set --params_path with correct pretrained model file")

id2corpus = gen_id2corpus(corpus_file)

corpus_list = [{idx: text} for idx, text in id2corpus.items()]
corpus_ds = MapDataset(corpus_list)

corpus_data_loader = create_dataloader(corpus_ds,
                                       mode='predict',
                                       batch_size=batch_size,
                                       batchify_fn=batchify_fn,
                                       trans_fn=trans_func)

inner_model = model
final_index = build_index(corpus_data_loader, inner_model)

text_list, text2similar_text = gen_text_file(similar_text_pair_file)

query_ds = MapDataset(text_list)
query_data_loader = create_dataloader(query_ds,
                                      mode='predict',
                                      batch_size=batch_size,
                                      batchify_fn=batchify_fn,
                                      trans_fn=trans_func)

query_embedding = inner_model.get_semantic_embedding(query_data_loader)

if not os.path.exists(recall_result_dir):
    os.mkdir(recall_result_dir)
recall_result_file = os.path.join(recall_result_dir, recall_result_file)

with open(recall_result_file, 'w', encoding='utf-8') as f:
    for batch_index, batch_query_embedding in enumerate(query_embedding):
        # Retrieve the recall_num nearest corpus texts for each query in the batch
        recalled_idx, cosine_sims = final_index.knn_query(
            batch_query_embedding.numpy(), recall_num)
        for row_index in range(len(cosine_sims)):
            # Offset with the configured batch_size so the last (possibly smaller)
            # batch still maps to the correct entry in text_list
            text_index = batch_size * batch_index + row_index
            for idx, doc_idx in enumerate(recalled_idx[row_index]):
                f.write("{}\t{}\t{}\n".format(
                    text_list[text_index]["text"], id2corpus[doc_idx],
                    1.0 - cosine_sims[row_index][idx]))

def recall(rs, N=10):
    # np.sum(r[0:N]) is 1 if the correct text appears among the top N recalled
    # results for a query, so the mean over queries is Recall@N
    recall_flags = [np.sum(r[0:N]) for r in rs]
    return np.mean(recall_flags)
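
A toy example of the metric: with two queries, where the first finds its gold text at rank 2 and the second never does, Recall@5 is 0.5.

toy_rs = [[0, 1, 0, 0, 0],
          [0, 0, 0, 0, 0]]
print(recall(toy_rs, N=5))  # (1 + 0) / 2 = 0.5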

similar_text_pair='./datasets/bxqa/test_pair.csv'

recall_num=10

text2similar = {}
with open(similar_text_pair, 'r', encoding='utf-8') as f:
    for line in f:
        text, similar_text = line.rstrip().split("\t")
        text2similar[text] = similar_text
print(len(text2similar.values()))
rs = []
with open(recall_result_file, 'r', encoding='utf-8') as f:
    relevance_labels = []  # flags for the current query's top-N recalls: 1 if that recall is the gold similar text
    for index, line in enumerate(f):
        if index % recall_num == 0 and index != 0:  # recall_num lines read, move on to the next query
            rs.append(relevance_labels)
            relevance_labels = []
        text, recalled_text, cosine_sim = line.rstrip().split("\t")  # each line: query text, recalled text, similarity
        if text2similar[text] == recalled_text:  # the recalled text is the gold similar text
            relevance_labels.append(1)
        else:
            relevance_labels.append(0)
    if relevance_labels:  # keep the flags of the last query as well
        rs.append(relevance_labels)

recall_N = []
recall_num = [1, 5, 10]
result = open('result.tsv', 'a')
res = []
for topN in recall_num:
    R = round(100 * recall(rs, N=topN), 3)
    recall_N.append(str(R))
for key, val in zip(recall_num, recall_N):
    print('recall@{}={}'.format(key, val))
    res.append(str(val))
result.write('\t'.join(res) + '\n')
result.close()

 
