import os
import random
import time
import numpy as np
import paddle
from functools import partial
from paddlenlp.utils.log import logger
from paddlenlp.data import Tuple, Pad
from paddlenlp.datasets import load_dataset, MapDataset
from paddlenlp.transformers import AutoModel, AutoTokenizer
from paddlenlp.transformers import LinearDecayWithWarmup
import paddle.nn.functional as F
import paddle.nn as nn
import abc
import hnswlib
# After unsupervised pre-training, this script loads the resulting model and continues training it on supervised semantic-similarity pairs.
class SemanticIndexBase(nn.Layer):
def __init__(self, pretrained_model, dropout=None, output_emb_size=None):
super().__init__()
self.ptm = pretrained_model
self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
self.output_emb_size = output_emb_size
if output_emb_size > 0:
weight_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
self.emb_reduce_linear = paddle.nn.Linear(768,
output_emb_size,
weight_attr=weight_attr)
def get_pooled_embedding(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
        _, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,  # second output is the pooled [CLS] embedding
                                    attention_mask)
if self.output_emb_size > 0:
cls_embedding = self.emb_reduce_linear(cls_embedding)
cls_embedding = self.dropout(cls_embedding)
        cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)  # L2-normalize to unit vectors, shape (n, d)
return cls_embedding
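    # Note: because get_pooled_embedding L2-normalizes its output, a plain dot product
    # between two embeddings equals their cosine similarity; both the in-batch negative
    # loss below and the hnswlib inner-product index rely on this property.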
def get_semantic_embedding(self, data_loader):
self.eval()
with paddle.no_grad():
for batch_data in data_loader:
input_ids, token_type_ids = batch_data
text_embeddings = self.get_pooled_embedding(
input_ids, token_type_ids=token_type_ids)
                yield text_embeddings  # yields the embeddings of the batch inputs, shape (n, d)
def cosine_sim(self,
query_input_ids,
title_input_ids,
query_token_type_ids=None,
query_position_ids=None,
query_attention_mask=None,
title_token_type_ids=None,
title_position_ids=None,
title_attention_mask=None):
query_cls_embedding = self.get_pooled_embedding(query_input_ids,
query_token_type_ids,
query_position_ids,
query_attention_mask)
title_cls_embedding = self.get_pooled_embedding(title_input_ids,
title_token_type_ids,
title_position_ids,
title_attention_mask)
        cosine_sim = paddle.sum(query_cls_embedding * title_cls_embedding,
                                axis=-1)  # (n,) cosine similarity between each query and its title
return cosine_sim
    @abc.abstractmethod  # abstract method: subclasses must override forward
def forward(self):
pass
class SemanticIndexBatchNeg(SemanticIndexBase):
def __init__(self,
pretrained_model,
dropout=None,
margin=0.3,
scale=30,
output_emb_size=None):
super().__init__(pretrained_model, dropout, output_emb_size)
self.margin = margin
        # scale factor used to amplify the cosine similarity values
        self.scale = scale
def forward(self,
query_input_ids,
title_input_ids,
query_token_type_ids=None,
query_position_ids=None,
query_attention_mask=None,
title_token_type_ids=None,
title_position_ids=None,
title_attention_mask=None):
query_cls_embedding = self.get_pooled_embedding(query_input_ids,
query_token_type_ids,
query_position_ids,
query_attention_mask)
title_cls_embedding = self.get_pooled_embedding(title_input_ids,
title_token_type_ids,
title_position_ids,
title_attention_mask)
        # Each row holds the cosine similarities between one query and every title in the batch;
        # each column holds the similarities between one title and every query.
        # Diagonal entries correspond to the true (query, title) pairs, so they are the positives;
        # all off-diagonal entries serve as in-batch negatives.
        cosine_sim = paddle.matmul(query_cls_embedding,  # (n, d) @ (d, n) = (n, n)
                                   title_cls_embedding,
                                   transpose_y=True)
        # subtract the margin from the positive samples (the diagonal of cosine_sim)
margin_diag = paddle.full(shape=[query_cls_embedding.shape[0]],
fill_value=self.margin,
dtype=paddle.get_default_dtype())
        cosine_sim = cosine_sim - paddle.diag(margin_diag)  # only the diagonal entries are reduced by the margin
        cosine_sim *= self.scale  # amplify the cosine similarities
        labels = paddle.arange(0, query_cls_embedding.shape[0], dtype='int64')  # ground-truth labels
        labels = paddle.reshape(labels, shape=[-1, 1])  # (n, 1)
        loss = F.cross_entropy(input=cosine_sim, label=labels)  # multi-class cross entropy over (logits, labels)
return loss
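    # Worked illustration of the in-batch negative objective (comments only, not executed):
    # with a batch of n = 3 pairs, cosine_sim is a 3 x 3 matrix where row i scores query i
    # against all three titles; labels = [[0], [1], [2]], so cross_entropy treats each row
    # as a 3-way classification whose correct class is the query's own title. Subtracting
    # the margin from the diagonal and multiplying by scale sharpens the softmax so that
    # positives must beat the in-batch negatives by at least the margin.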
def read_text_pair(data_path):
with open(data_path, 'r', encoding='utf-8') as f:
for line in f:
data = line.rstrip().split("\t")
if len(data) != 2:
continue
yield {'text_a': data[0], 'text_b': data[1]}
def convert_example(example,
tokenizer,
max_seq_length=512,
pad_to_max_seq_len=False):
result = []
for key, text in example.items():
encoded_inputs = tokenizer(text=text,
max_seq_len=max_seq_length,
pad_to_max_seq_len=pad_to_max_seq_len)
input_ids = encoded_inputs["input_ids"]
token_type_ids = encoded_inputs["token_type_ids"]
result += [input_ids, token_type_ids]
return result
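# For a pair example {'text_a': ..., 'text_b': ...} the returned list is
# [query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids],
# matching the four Pad fields of the training batchify_fn below; single-text
# corpus/query examples yield just [input_ids, token_type_ids].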
def create_dataloader(dataset,
mode='train',
batch_size=1,
batchify_fn=None,
trans_fn=None):
if trans_fn:
dataset = dataset.map(trans_fn)
shuffle = True if mode == 'train' else False
batch_sampler = paddle.io.BatchSampler(dataset,
batch_size=batch_size,
shuffle=shuffle)
return paddle.io.DataLoader(dataset=dataset,
batch_sampler=batch_sampler,
collate_fn=batchify_fn,
return_list=True)
def gen_id2corpus(corpus_file):
id2corpus = {}
with open(corpus_file, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
id2corpus[idx] = line.rstrip()
return id2corpus
def gen_text_file(similar_text_pair_file):
text2similar_text = {}
texts = []
with open(similar_text_pair_file, 'r', encoding='utf-8') as f:
for line in f:
            split_line = line.rstrip().split("\t")
            if len(split_line) != 2:
                continue
            text, similar_text = split_line
if not text or not similar_text:
continue
text2similar_text[text] = similar_text
texts.append({"text": text})
return texts, text2similar_text
output_emb_size=256
hnsw_max_elements=1000000
hnsw_ef=100
hnsw_m=100
def build_index(data_loader, model):
index = hnswlib.Index(
space='ip',
dim=output_emb_size if output_emb_size > 0 else 768)
index.init_index(max_elements=hnsw_max_elements,
ef_construction=hnsw_ef,
M=hnsw_m)
index.set_ef(hnsw_ef)
index.set_num_threads(10)
logger.info("start build index..........")
all_embeddings = []
for text_embeddings in model.get_semantic_embedding(data_loader):
all_embeddings.append(text_embeddings.numpy())
all_embeddings = np.concatenate(all_embeddings, axis=0)
index.add_items(all_embeddings)
logger.info("Total index number:{}".format(index.get_current_count()))
return index
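# hnswlib's 'ip' space ranks by 1 - inner product; with unit-length embeddings this is
# 1 - cosine similarity, which is why evaluate() writes 1.0 - cosine_sims as the score.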
save_dir='./checkpoint/yysy/'
max_seq_length=64
batch_size=6
model_name_or_path='rocketqa-zh-base-query-encoder'
learning_rate=5e-5
weight_decay=0.0
epochs=5
warmup_proportion=0.0
init_from_ckpt='./checkpoint/yysy/yysy_recall_best_2.pdparams'
seed=1000
device='gpu'
train_set_file='./datasets/yysy/recall/train.csv'
dev_set_file='./datasets/yysy/recall/dev_s.csv'
margin=0.2
scale=30
corpus_file='./datasets/yysy/recall/corpus_s.csv'
similar_text_pair_file='./datasets/yysy/recall/dev_s.csv'
recall_result_dir='./recall_result_dir/'
recall_result_file='batch_neg_recall_result.txt'
recall_num=20
evaluate_result='batch_neg_evaluate_result.txt'
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
@paddle.no_grad()
def evaluate(model, corpus_data_loader, query_data_loader, recall_result_file,
text_list, id2corpus):
def recall(rs, N=10):
recall_flags = [np.sum(r[0:N]) for r in rs]
return np.mean(recall_flags)
recall_num=20
similar_text_pair_file='./datasets/yysy/recall/dev_s.csv'
inner_model = model._layers
    final_index = build_index(corpus_data_loader, inner_model)  # build the ANN index over the corpus
    query_embedding = inner_model.get_semantic_embedding(query_data_loader)  # generator of query embeddings
    with open(recall_result_file, 'w', encoding='utf-8') as f:  # write the recall results to file
for batch_index, batch_query_embedding in enumerate(query_embedding):
recalled_idx, cosine_sims = final_index.knn_query(
batch_query_embedding.numpy(), recall_num)
batch_size = len(cosine_sims)
for row_index in range(batch_size):
text_index = batch_size * batch_index + row_index
for idx, doc_idx in enumerate(recalled_idx[row_index]):
f.write("{}\t{}\t{}\n".format(
text_list[text_index]["text"], id2corpus[doc_idx],
1.0 - cosine_sims[row_index][idx]))
    text2similar = {}  # mapping from each text to its annotated similar text
with open(similar_text_pair_file, 'r', encoding='utf-8') as f:
for line in f:
text, similar_text = line.rstrip().split("\t")
text2similar[text] = similar_text
rs = []
with open(recall_result_file, 'r', encoding='utf-8') as f:
relevance_labels = []
for index, line in enumerate(f):
            if index % recall_num == 0 and index != 0:  # every recall_num lines belong to one query; flush its relevance flags
rs.append(relevance_labels)
relevance_labels = []
text, recalled_text, cosine_sim = line.rstrip().split("\t")
if text == recalled_text:
continue
            if text2similar[text] == recalled_text:  # the annotated similar text was recalled: flag 1
                relevance_labels.append(1)
            else:
                relevance_labels.append(0)  # otherwise flag 0: the similar text was not found at this rank
recall_N = []
recall_num = [1, 5, 10, 20, 50]
for topN in recall_num:
R = round(100 * recall(rs, N=topN), 3)
recall_N.append(str(R))
# evaluate_result_file = os.path.join(recall_result_dir,
# evaluate_result)
# result = open(evaluate_result_file, 'a')
res = []
timestamp = time.strftime('%Y%m%d-%H%M%S', time.localtime())
res.append(timestamp)
for key, val in zip(recall_num, recall_N):
print('recall@{}={}'.format(key, val))
res.append(str(val))
# result.write('\t'.join(res) + '\n')
print(recall_N)
return float(recall_N[1])
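# recall@N is the fraction of dev queries whose annotated similar text appears among the
# top-N recalled candidates; it is computed for N in [1, 5, 10, 20, 50] and recall@5
# (recall_N[1]) is returned for model selection.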
paddle.set_device(device)
set_seed(seed)
train_ds = load_dataset(read_text_pair,
data_path=train_set_file,
lazy=False)
pretrained_model = AutoModel.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
trans_func = partial(convert_example,
tokenizer=tokenizer,
max_seq_length=max_seq_length)
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'
), # query_input
Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'
), # query_segment
Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'
), # title_input
Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'
    ), # title_segment
): [data for data in fn(samples)]
train_data_loader = create_dataloader(train_ds,
mode='train',
batch_size=batch_size,
batchify_fn=batchify_fn,
trans_fn=trans_func)
for i in train_data_loader:
print(i)
break
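# The batch printed above is a list of four int64 tensors of shape
# (batch_size, max_seq_len_in_batch): query ids, query segment ids, title ids, title segment ids.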
model = SemanticIndexBatchNeg(pretrained_model,
margin=margin,
scale=scale,
output_emb_size=output_emb_size)
if init_from_ckpt and os.path.isfile(init_from_ckpt):
state_dict = paddle.load(init_from_ckpt)
model.set_dict(state_dict)
print("warmup from:{}".format(init_from_ckpt))
model = paddle.DataParallel(model)
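# DataParallel wraps the model; evaluate() unwraps it via model._layers before building
# the index and embedding the queries.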
batchify_fn_dev = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'
), # text_input
Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'
), # text_segment
): [data for data in fn(samples)]
id2corpus = gen_id2corpus(corpus_file)
corpus_list = [{idx: text} for idx, text in id2corpus.items()]
corpus_ds = MapDataset(corpus_list)
corpus_data_loader = create_dataloader(corpus_ds,  # corpus used to build the index
mode='predict',
batch_size=batch_size,
batchify_fn=batchify_fn_dev,
trans_fn=trans_func)
text_list, text2similar_text = gen_text_file(similar_text_pair_file)
query_ds = MapDataset(text_list)
query_data_loader = create_dataloader(query_ds,  # evaluation queries
mode='predict',
batch_size=batch_size,
batchify_fn=batchify_fn_dev,
trans_fn=trans_func)
if not os.path.exists(recall_result_dir):  # directory for the recall result file
os.mkdir(recall_result_dir)
recall_result_file = os.path.join(recall_result_dir,
recall_result_file)
num_training_steps = len(train_data_loader) * epochs
lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps,
warmup_proportion)
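# With warmup_proportion = 0.0 there is no warmup phase: the learning rate starts at
# learning_rate and decays linearly to 0 over num_training_steps.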
decay_params = [
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
parameters=model.parameters(),
weight_decay=weight_decay,
apply_decay_param_fun=lambda x: x in decay_params)
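# Standard BERT-style fine-tuning setup: bias and LayerNorm parameters are excluded from
# weight decay via apply_decay_param_fun (weight_decay is 0.0 here, so it has no effect).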
def train(model, dataloader, optimizer, scheduler, global_step):
    model.train()
    total_loss = 0.
for step, batch in enumerate(dataloader, start=1):
query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids = batch
loss = model(
query_input_ids=query_input_ids,
title_input_ids=title_input_ids,
query_token_type_ids=query_token_type_ids,
title_token_type_ids=title_token_type_ids)
        loss.backward()         # back-propagation
        optimizer.step()        # update parameters with the gradients
        scheduler.step()        # update the learning rate
        optimizer.clear_grad()  # reset gradients to zero
global_step += 1
total_loss += loss.item()
        if global_step % 30 == 0 and global_step != 0:
            print("global_step %d - batch: %d - loss: %.5f"
                  % (global_step, step, loss.item()))
    avg_loss = total_loss / len(dataloader)
    return avg_loss, global_step
def train_epochs(epochs, save_path):
    global_step = 0  # global training step counter
    best_recall = 0.0
    for epoch in range(1, epochs + 1):
        avg_loss, global_step = train(model, train_data_loader, optimizer,
                                      lr_scheduler, global_step)
        print("epoch:%d - global_step:%d - avg_loss: %.4f - best_recall: %.4f - lr:%.8f"
              % (epoch, global_step, avg_loss, best_recall, optimizer.get_lr()))
recall_5 = evaluate(model, corpus_data_loader, query_data_loader,
recall_result_file, text_list, id2corpus)
        if recall_5 > best_recall:
            paddle.save(model.state_dict(),
                        save_path + 'yysy_in_batch_negative_recall_1.pdparams')
            tokenizer.save_pretrained(save_path)
            best_recall = recall_5
if not os.path.exists(save_dir): os.makedirs(save_dir)
train_epochs(epochs,save_dir)