以下是我实现的代码，如有不对之处，恳请指正。
import nlp_basictasks
import os,json
import numpy as np
import torch
import torch.nn as nn
import random
from tqdm.autonotebook import tqdm, trange
from torch.utils.data import DataLoader
from nlp_basictasks.modules import SBERT
from nlp_basictasks.modules.transformers import BertTokenizer,BertModel,BertConfig
from nlp_basictasks.readers.sts import InputExample,convert_examples_to_features,getExamples,convert_sentences_to_features
from nlp_basictasks.modules.utils import get_optimizer,get_scheduler
from nlp_basictasks.Trainer import Trainer
from nlp_basictasks.evaluation import stsEvaluator
from sentence_transformers import SentenceTransformer,models
# Pretrained checkpoint directory and the tokenizer used by the data pipeline.
model_path = 'chinese-roberta-wwm/'
tokenizer = BertTokenizer.from_pretrained(model_path)

# Tokenization length limit and training batch size.
max_seq_len = 64
batch_size = 128

# CNSD-STS splits. Dataset source: https://github.com/pluto-junzeng/CNSD
train_file = 'cnsd-sts-train.txt'
dev_file = 'cnsd-sts-dev.txt'
test_file = 'cnsd-sts-test.txt'
def read_data(file_path, encoding='utf-8'):
    """Read one CNSD-STS split file.

    Every non-blank line must look like ``id||sentence1||sentence2||label``.

    Args:
        file_path: path of the split file.
        encoding: text encoding of the file. Given explicitly because the
            platform default (used when omitted) is not UTF-8 everywhere,
            e.g. on some Windows locales; the data is assumed to be UTF-8.

    Returns:
        ``(sentences, labels)``: ``sentences`` is a list of
        ``[sentence1, sentence2]`` pairs and ``labels`` the matching list of
        label strings (left unconverted).

    Raises:
        ValueError: if a non-blank line has fewer than four ``||`` fields.
    """
    sentences = []
    labels = []
    with open(file_path, encoding=encoding) as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline) instead of
                # failing with an IndexError.
                continue
            fields = line.split('||')
            if len(fields) < 4:
                raise ValueError(
                    f'{file_path}:{line_no}: expected '
                    f'"id||sentence1||sentence2||label", got {line!r}')
            sentences.append([fields[1], fields[2]])
            labels.append(fields[3])
    return sentences, labels
# Load the three splits; each gives [[sentence1, sentence2], ...] plus label strings.
train_sentences, train_labels = read_data(train_file)
dev_sentences, dev_labels = read_data(dev_file)
test_sentences, test_labels = read_data(test_file)
print(train_sentences[:2], train_labels[:2])
print(dev_sentences[:2], dev_labels[:2])
print(test_sentences[:2], test_labels[:2])

# Unsupervised setting: keep only the first sentence of every training pair
# and pair it with itself (both texts identical, label fixed to 1).
train_sentences = [pair[0] for pair in train_sentences]
print(len(train_sentences))
print(train_sentences[:3])
train_examples = [InputExample(text_list=[text, text], label=1) for text in train_sentences]
train_dataloader = DataLoader(train_examples, batch_size=batch_size, shuffle=True)
def smart_batching_collate(batch):
    """Collate a list of InputExample objects into tokenized model inputs.

    Tokenizes with the module-level ``tokenizer`` and ``max_seq_len`` via
    ``convert_examples_to_features`` and returns its three results as a
    ``(features_of_a, features_of_b, labels)`` tuple.
    """
    first_features, second_features, batch_labels = convert_examples_to_features(
        examples=batch,
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
    )
    return first_features, second_features, batch_labels
# Tokenize batches lazily when the DataLoader assembles them.
train_dataloader.collate_fn = smart_batching_collate
print(train_examples[0])

# Dev set: split each [sentence1, sentence2] pair into two parallel lists and
# turn the label strings into integer scores for the STS evaluator.
print(dev_sentences[0], dev_labels[0])
sentences1_list = [pair[0] for pair in dev_sentences]
sentences2_list = [pair[1] for pair in dev_sentences]
dev_labels = list(map(int, dev_labels))
evaluator = stsEvaluator(
    sentences1=sentences1_list,
    sentences2=sentences2_list,
    scores=dev_labels,
    batch_size=64,
    write_csv=True,
)
from queue import Queue
class ESimCSE(nn.Module):
    """ESimCSE: unsupervised SimCSE with word-repetition positives and a momentum-encoder queue.

    ``encoder`` is the model being trained. ``moco_encoder`` is a separate copy
    of it, updated only as an exponential moving average of ``encoder``. Its
    embeddings are stored in ``self.q`` (a list of 1-D tensors) and used as
    extra negatives in later steps.
    """

    def __init__(self,
                 bert_model_path,
                 q_size=256,
                 dup_rate=0.32,
                 is_sbert_model=True,
                 temperature=0.05,
                 is_distilbert=False,
                 gamma=0.99,
                 device='cpu'):
        """Build the trainable encoder and its momentum copy.

        Args:
            bert_model_path: a saved SentenceTransformer (``is_sbert_model=True``)
                or a plain transformer checkpoint wrapped with
                ``models.Transformer`` + ``models.Pooling`` otherwise.
            q_size: maximum number of momentum embeddings kept as negatives.
            dup_rate: scales the maximum number of duplicated tokens per
                sentence in :meth:`word_repetition`.
            is_sbert_model: see ``bert_model_path``.
            temperature: divisor applied to cosine similarities before the loss.
            is_distilbert: drop ``token_type_ids`` (DistilBERT does not accept them).
            gamma: momentum coefficient of the EMA update.
            device: device given to both SentenceTransformer instances.
        """
        super().__init__()
        self.encoder = self._build_encoder(bert_model_path, is_sbert_model, device)
        # The momentum encoder must be built from *fresh* module objects: putting the
        # same Transformer/Pooling instances into two SentenceTransformers makes them
        # share parameters, which turns the EMA update into a no-op.
        self.moco_encoder = self._build_encoder(bert_model_path, is_sbert_model, device)
        # Start from exactly the encoder's weights (covers any randomly initialised
        # weights the checkpoint does not contain).
        self.moco_encoder.load_state_dict(self.encoder.state_dict())
        # Only updated through the EMA in _momentum_update, never by the optimizer.
        for moco_param in self.moco_encoder.parameters():
            moco_param.requires_grad_(False)
        self.gamma = gamma
        self.q = []  # queue of 1-D momentum embeddings, oldest first
        self.q_size = q_size
        self.dup_rate = dup_rate
        self.temperature = temperature
        self.is_distilbert = is_distilbert  # distilled BERT does not support token_type_ids

    @staticmethod
    def _build_encoder(bert_model_path, is_sbert_model, device):
        """Create one SentenceTransformer with brand-new module instances."""
        if is_sbert_model:
            return SentenceTransformer(model_name_or_path=bert_model_path, device=device)
        word_embedding_model = models.Transformer(bert_model_path, max_seq_length=max_seq_len)
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
        return SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)

    def cal_cos_sim(self, embeddings1, embeddings2):
        """Return the pairwise cosine-similarity matrix, shape (len(embeddings1), len(embeddings2))."""
        embeddings1_norm = torch.nn.functional.normalize(embeddings1, p=2, dim=1)
        embeddings2_norm = torch.nn.functional.normalize(embeddings2, p=2, dim=1)
        return torch.mm(embeddings1_norm, embeddings2_norm.transpose(0, 1))

    def word_repetition(self, sentence_feature):
        """Build the positive view of a batch by duplicating randomly chosen tokens.

        For each row, ``dup_len = randint(0, max(2, int(dup_rate * actual_len)))``
        distinct positions are drawn from ``[1, actual_len)``. Position 0
        (presumably [CLS]) is never duplicated. ``dup_len`` is clamped to the
        number of available positions. Each chosen token is repeated once in
        place, and rows are then right-padded with 0 to the longest result.

        Args:
            sentence_feature: dict with ``input_ids`` and ``attention_mask``
                tensors of shape (batch, seq_len) and optionally
                ``token_type_ids``. Assumes right padding (real tokens first),
                because ``actual_len`` is taken as the mask sum.

        Returns:
            dict of CPU LongTensors with the same keys as the input
            (``token_type_ids`` only if it was given).
        """
        input_ids = sentence_feature['input_ids'].cpu().tolist()
        attention_mask = sentence_feature['attention_mask'].cpu().tolist()
        # DistilBERT inputs arrive without token_type_ids (removed in forward).
        has_token_type_ids = 'token_type_ids' in sentence_feature
        if has_token_type_ids:
            token_type_ids = sentence_feature['token_type_ids'].cpu().tolist()
        else:
            token_type_ids = [[0] * len(row) for row in input_ids]

        repeated_input_ids, repeated_attention_mask, repeated_token_type_ids = [], [], []
        for cur_input_id, sample_mask, cur_token_type_ids in zip(input_ids, attention_mask, token_type_ids):
            actual_len = sum(sample_mask)
            dup_len = random.randint(a=0, b=max(2, int(self.dup_rate * actual_len)))
            candidates = list(range(1, actual_len))
            # Very short rows (e.g. only [CLS][SEP]) have fewer candidates than dup_len;
            # random.sample would raise ValueError without this clamp.
            dup_word_index = set(random.sample(candidates, k=min(dup_len, len(candidates))))
            r_input_id, r_attention_mask, r_token_type_ids = [], [], []
            for index, (word_id, mask, token_type) in enumerate(
                    zip(cur_input_id, sample_mask, cur_token_type_ids)):
                repeat = 2 if index in dup_word_index else 1
                r_input_id.extend([word_id] * repeat)
                r_attention_mask.extend([mask] * repeat)
                r_token_type_ids.extend([token_type] * repeat)
            repeated_input_ids.append(r_input_id)
            repeated_attention_mask.append(r_attention_mask)
            repeated_token_type_ids.append(r_token_type_ids)

        # Right-pad every row with 0 to the longest duplicated row.
        rep_seq_len = max((len(row) for row in repeated_input_ids), default=0)
        for rows in (repeated_input_ids, repeated_attention_mask, repeated_token_type_ids):
            for row in rows:
                row.extend([0] * (rep_seq_len - len(row)))

        result = {'input_ids': torch.LongTensor(repeated_input_ids),
                  'attention_mask': torch.LongTensor(repeated_attention_mask)}
        if has_token_type_ids:
            result['token_type_ids'] = torch.LongTensor(repeated_token_type_ids)
        return result

    def _momentum_encode(self, features):
        """Encode ``features`` with the momentum encoder: dropout off, no gradients."""
        # eval() is what actually disables dropout (editing config dropout values after
        # construction does not touch existing Dropout layers). It is re-applied on every
        # call because model.train() on this parent module switches children back to
        # training mode.
        self.moco_encoder.eval()
        with torch.no_grad():
            # Pass a copy: sentence-transformers modules may write outputs into the dict.
            return self.moco_encoder(dict(features))['sentence_embedding']

    def _dequeue_and_enqueue(self, embeddings):
        """Append ``embeddings`` row by row, then drop the oldest entries beyond ``q_size``."""
        self.q.extend(embeddings)
        overflow = len(self.q) - self.q_size
        if overflow > 0:
            del self.q[:overflow]

    @torch.no_grad()
    def _momentum_update(self):
        """EMA update: moco_param <- gamma * moco_param + (1 - gamma) * encoder_param."""
        for encoder_param, moco_param in zip(self.encoder.parameters(), self.moco_encoder.parameters()):
            moco_param.mul_(self.gamma).add_(encoder_param.detach(), alpha=1.0 - self.gamma)

    def forward(self, batch_inputs):
        """Compute the ESimCSE contrastive loss for one training batch.

        For compatibility with the rest of the project, ``batch_inputs`` is
        ``(batch1_features, batch2_features, labels)``. The trailing labels entry
        is required by that convention but ignored here; it may be None.

        batch1 is encoded as-is and batch2 after :meth:`word_repetition`. Row i
        of both forms the positive pair. All other batch2 rows and the queued
        momentum embeddings are negatives. Embeddings are the pooled
        ``sentence_embedding`` outputs. Their width depends on the pooling
        setup (e.g. 768 for a base-size model with one pooling mode; not fixed
        here).

        Returns:
            Scalar cross-entropy loss over the (N, N + M) similarity matrix
            (N = batch size, M = number of queued negatives, 0 at first).
        """
        batch1_features, batch2_features, _ = batch_inputs
        if self.is_distilbert:
            batch1_features.pop('token_type_ids', None)
            batch2_features.pop('token_type_ids', None)
        # encode() delegates to SentenceTransformer.encode, which can leave the encoder in
        # eval mode; re-sync it so dropout (SimCSE's noise) is active while training.
        self.encoder.train(self.training)
        device = self.encoder.device

        batch1_features = {key: value.to(device) for key, value in batch1_features.items()}
        # Copy: batch1_features is reused below for the momentum encoder.
        batch1_embeddings = self.encoder(dict(batch1_features))['sentence_embedding']
        batch2_features = self.word_repetition(sentence_feature=batch2_features)
        batch2_features = {key: value.to(device) for key, value in batch2_features.items()}
        batch2_embeddings = self.encoder(batch2_features)['sentence_embedding']

        cos_sim = self.cal_cos_sim(batch1_embeddings, batch2_embeddings) / self.temperature  # (N, N)
        labels = torch.arange(cos_sim.size(0), device=cos_sim.device)
        if self.q:
            # Queue contents come from previous batches only.
            negative_samples = torch.vstack(self.q[:self.q_size])  # (M, dim)
            cos_sim_with_neg = self.cal_cos_sim(batch1_embeddings, negative_samples) / self.temperature  # (N, M)
            cos_sim = torch.cat([cos_sim, cos_sim_with_neg], dim=1)  # (N, N + M)

        self._dequeue_and_enqueue(self._momentum_encode(batch1_features))
        self._momentum_update()
        return nn.CrossEntropyLoss()(cos_sim, labels)

    def encode(self, sentences,
               batch_size: int = 32,
               show_progress_bar: bool = None,
               output_value: str = 'sentence_embedding',
               convert_to_numpy: bool = True,
               convert_to_tensor: bool = False,
               device: str = None,
               normalize_embeddings: bool = False):
        """Encode ``sentences`` with the trained encoder (delegates to SentenceTransformer.encode).

        Original author's note: ``sentences`` is expected to be a single batch.
        """
        return self.encoder.encode(sentences=sentences,
                                   batch_size=batch_size,
                                   show_progress_bar=show_progress_bar,
                                   output_value=output_value,
                                   convert_to_numpy=convert_to_numpy,
                                   convert_to_tensor=convert_to_tensor,
                                   device=device,
                                   normalize_embeddings=normalize_embeddings)

    def save(self, output_path):
        """Write ``model_param_config.json`` and the trained encoder into ``output_path``."""
        os.makedirs(output_path, exist_ok=True)
        with open(os.path.join(output_path, 'model_param_config.json'), 'w') as fOut:
            json.dump(self.get_config_dict(output_path), fOut)
        self.encoder.save(output_path)

    def get_config_dict(self, output_path):
        """Return the constructor kwargs that :meth:`load` needs to re-create this model."""
        return {'bert_model_path': output_path, 'temperature': self.temperature, 'is_distilbert': self.is_distilbert,
                'q_size': self.q_size, 'dup_rate': self.dup_rate, 'gamma': self.gamma}

    @staticmethod
    def load(input_path):
        """Re-create a model saved by :meth:`save` (encoder loaded as a SentenceTransformer)."""
        with open(os.path.join(input_path, 'model_param_config.json')) as fIn:
            config = json.load(fIn)
        return ESimCSE(**config)
# Use a GPU when one is available (the original hard-coded 'cpu', which makes
# training a transformer encoder very slow). ESimCSE.forward moves every batch to
# the encoder's device, so this single setting is enough.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
esimcse = ESimCSE(bert_model_path=model_path,
                  is_distilbert=False,
                  is_sbert_model=False,
                  dup_rate=0.32,
                  gamma=0.99,
                  device=device)
# Dev-set score of the untrained encoder, as a baseline.
evaluator(esimcse)

epochs = 5
output_path = '定义想要模型保存的路径'  # TODO: set the directory where the model should be saved
tensorboard_logdir = os.path.join(output_path, 'log')
optimizer_type = 'AdamW'
# Kept separate from the `scheduler` object created below (the original reused one name for both).
scheduler_type = 'WarmupLinear'
warmup_proportion = 0.1
optimizer_params = {'lr': 2e-5}
weight_decay = 0.01
num_train_steps = len(train_dataloader) * epochs
# Step counts are integers; the original passed a float here.
warmup_steps = int(num_train_steps * warmup_proportion)
optimizer = get_optimizer(model=esimcse, optimizer_type=optimizer_type, weight_decay=weight_decay,
                          optimizer_params=optimizer_params)
scheduler = get_scheduler(optimizer, scheduler=scheduler_type, warmup_steps=warmup_steps,
                          t_total=num_train_steps)
trainer = Trainer(epochs=epochs, output_path=output_path, tensorboard_logdir=tensorboard_logdir,
                  early_stop_patience=20)
trainer.train(train_dataloader=train_dataloader,
              model=esimcse,
              optimizer=optimizer,
              scheduler=scheduler,
              evaluator=evaluator,
              )
数据集来源：STS-B 的中文版（CNSD 提供的 cnsd-sts 数据）。
所有实验代码可参考：复现 SimCSE、ESimCSE 等论文。