Reproducing the ESimCSE Paper

Below is the code of my implementation; if anything is wrong, corrections are very welcome.

import nlp_basictasks
import os,json
import numpy as np
import torch
import torch.nn as nn
import random
from tqdm.autonotebook import tqdm, trange
from torch.utils.data import DataLoader
from nlp_basictasks.modules import SBERT
from nlp_basictasks.modules.transformers import BertTokenizer,BertModel,BertConfig
from nlp_basictasks.readers.sts import InputExample,convert_examples_to_features,getExamples,convert_sentences_to_features
from nlp_basictasks.modules.utils import get_optimizer,get_scheduler
from nlp_basictasks.Trainer import Trainer
from nlp_basictasks.evaluation import stsEvaluator
from sentence_transformers import SentenceTransformer,models
model_path='chinese-roberta-wwm/'
tokenizer=BertTokenizer.from_pretrained(model_path)
max_seq_len=64
batch_size=128

# Dataset source: https://github.com/pluto-junzeng/CNSD
train_file='cnsd-sts-train.txt'
dev_file='cnsd-sts-dev.txt'
test_file='cnsd-sts-test.txt'
def read_data(file_path):
    sentences=[]
    labels=[]
    with open(file_path,encoding='utf-8') as f:
        lines=f.readlines()
    for line in lines:
        line_split=line.strip().split('||')
        sentences.append([line_split[1],line_split[2]])
        labels.append(line_split[3])
    return sentences,labels
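read_data only assumes that every line contains four '||'-separated fields: an id, the two sentences, and a 0-5 similarity score. Schematically (placeholder fields, not real data):

# sentence_id||sentence_1||sentence_2||score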


train_sentences,train_labels=read_data(train_file)
dev_sentences,dev_labels=read_data(dev_file)
test_sentences,test_labels=read_data(test_file)
print(train_sentences[:2],train_labels[:2])
print(dev_sentences[:2],dev_labels[:2])
print(test_sentences[:2],test_labels[:2])

train_sentences=[sentence[0] for sentence in train_sentences]# unsupervised setting: keep only the first sentence of each pair
print(len(train_sentences))
print(train_sentences[:3])
train_examples=[InputExample(text_list=[sentence,sentence],label=1) for sentence in train_sentences]# each sentence is paired with itself; the two views diverge via dropout and word repetition
train_dataloader=DataLoader(train_examples,shuffle=True,batch_size=batch_size)
def smart_batching_collate(batch):
    features_of_a,features_of_b,labels=convert_examples_to_features(examples=batch,tokenizer=tokenizer,max_seq_len=max_seq_len)
    return features_of_a,features_of_b,labels
train_dataloader.collate_fn=smart_batching_collate
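Note: the only contract the collate function has to satisfy is the one forward relies on below, namely that convert_examples_to_features returns two feature dicts (each with input_ids, attention_mask and token_type_ids) for the two texts of every InputExample, plus the labels.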
print(train_examples[0])

#dev_sentences=[example.text_list for example in dev_examples]
#dev_labels=[example.label for example in dev_examples]
print(dev_sentences[0],dev_labels[0])
sentences1_list=[sen[0] for sen in dev_sentences]
sentences2_list=[sen[1] for sen in dev_sentences]
dev_labels=[int(score) for score in dev_labels]
evaluator=stsEvaluator(sentences1=sentences1_list,sentences2=sentences2_list,batch_size=64,write_csv=True,scores=dev_labels)
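The evaluator encodes both sentence lists with the model's encode method and measures how well the cosine similarities of the embeddings track the gold scores (Spearman correlation in the usual STS setup).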


class ESimCSE(nn.Module):
    def __init__(self,
                 bert_model_path,
                 q_size=256,
                 dup_rate=0.32,
                 is_sbert_model=True,
                temperature=0.05,
                is_distilbert=False,
                 gamma=0.99,
                device='cpu'):
        super(ESimCSE,self).__init__()
        if is_sbert_model:
            self.encoder=SentenceTransformer(model_name_or_path=bert_model_path,device=device)
            self.moco_encoder=SentenceTransformer(model_name_or_path=bert_model_path,device=device)
        else:
            word_embedding_model = models.Transformer(bert_model_path, max_seq_length=max_seq_len)
            pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
            self.encoder=SentenceTransformer(modules=[word_embedding_model, pooling_model],device=device)
            # build the momentum encoder from its own modules; reusing word_embedding_model/pooling_model
            # would make encoder and moco_encoder share parameters and turn the momentum update into a no-op
            moco_word_embedding_model = models.Transformer(bert_model_path, max_seq_length=max_seq_len)
            moco_pooling_model = models.Pooling(moco_word_embedding_model.get_word_embedding_dimension())
            self.moco_encoder=SentenceTransformer(modules=[moco_word_embedding_model, moco_pooling_model],device=device)
        self.gamma=gamma
        self.q=[]
        self.q_size=q_size
        self.dup_rate=dup_rate
        self.temperature=temperature
        self.is_distilbert=is_distilbert# DistilBERT does not accept token_type_ids
    def cal_cos_sim(self,embeddings1,embeddings2):
        embeddings1_norm=torch.nn.functional.normalize(embeddings1,p=2,dim=1)
        embeddings2_norm=torch.nn.functional.normalize(embeddings2,p=2,dim=1)
        return torch.mm(embeddings1_norm,embeddings2_norm.transpose(0,1))#(batch_size,batch_size)

    def word_repetition(self,sentence_feature):
        input_ids, attention_mask, token_type_ids=sentence_feature['input_ids'].cpu().tolist(),sentence_feature['attention_mask'].cpu().tolist(),sentence_feature['token_type_ids'].cpu().tolist()
        bsz, seq_len = len(input_ids),len(input_ids[0])
        #print(bsz,seq_len)
        repetitied_input_ids=[]
        repetitied_attention_mask=[]
        repetitied_token_type_ids=[]
        rep_seq_len=seq_len
        for bsz_id in range(bsz):
            sample_mask = attention_mask[bsz_id]
            actual_len = sum(sample_mask)

            cur_input_id=input_ids[bsz_id]
            dup_len=random.randint(a=0,b=max(2,int(self.dup_rate*actual_len)))
            dup_len=min(dup_len,actual_len-1)# never request more duplicates than there are candidate positions
            dup_word_index=random.sample(list(range(1,actual_len)),k=dup_len)
            
            r_input_id=[]
            r_attention_mask=[]
            r_token_type_ids=[]
            for index,word_id in enumerate(cur_input_id):
                if index in dup_word_index:
                    r_input_id.append(word_id)
                    r_attention_mask.append(sample_mask[index])
                    r_token_type_ids.append(token_type_ids[bsz_id][index])

                r_input_id.append(word_id)
                r_attention_mask.append(sample_mask[index])
                r_token_type_ids.append(token_type_ids[bsz_id][index])

            after_dup_len=len(r_input_id)
            repetitied_input_ids.append(r_input_id)
            repetitied_attention_mask.append(r_attention_mask)
            repetitied_token_type_ids.append(r_token_type_ids)

            # every original position (padding included) is copied once, plus dup_len duplicated tokens
            assert after_dup_len==dup_len+seq_len
            if after_dup_len>rep_seq_len:
                rep_seq_len=after_dup_len

        for i in range(bsz):
            after_dup_len=len(repetitied_input_ids[i])
            pad_len=rep_seq_len-after_dup_len
            repetitied_input_ids[i]+=[0]*pad_len
            repetitied_attention_mask[i]+=[0]*pad_len
            repetitied_token_type_ids[i]+=[0]*pad_len

        repetitied_input_ids=torch.LongTensor(repetitied_input_ids)
        repetitied_attention_mask=torch.LongTensor(repetitied_attention_mask)
        repetitied_token_type_ids=torch.LongTensor(repetitied_token_type_ids)
        return {"input_ids":repetitied_input_ids,'attention_mask':repetitied_attention_mask,'token_type_ids':repetitied_token_type_ids}

    def forward(self,batch_inputs):
        '''
        For compatibility, the last element of every model's batch_inputs must be the labels, even if it is None.
        The encoder returns token_embeddings, cls_token_embeddings and sentence_embeddings;
        sentence_embeddings is the output of the pooling layer with dimension 768*k, where k depends on the
        pooling strategy. Usually only one strategy is used (CLS, mean of the last layer, mean of the last two
        layers, or mean of the first and last layers), so sentence_embeddings is 768-dimensional as well.
        '''
        batch1_features,batch2_features,_=batch_inputs
        if self.is_distilbert:
            del batch1_features['token_type_ids']
            del batch2_features['token_type_ids']
        batch1_features={key:value.to(self.encoder.device) for key, value in batch1_features.items()}
        batch1_embeddings=self.encoder(batch1_features)['sentence_embedding']
        batch2_features=self.word_repetition(sentence_feature=batch2_features)
        batch2_features={key:value.to(self.encoder.device) for key, value in batch2_features.items()}
        
        batch2_embeddings=self.encoder(batch2_features)['sentence_embedding']
        cos_sim=self.cal_cos_sim(batch1_embeddings,batch2_embeddings)/self.temperature#(batch_size,batch_size)
        batch_size=cos_sim.size(0)
        assert cos_sim.size()==(batch_size,batch_size)
        labels=torch.arange(batch_size).to(cos_sim.device)
        negative_samples=None
        if len(self.q)>0:
            negative_samples=torch.vstack(self.q[:self.q_size])#(q_size,768)
        if len(self.q)+batch_size>=self.q_size:
            del self.q[:batch_size]
            
        with torch.no_grad():
            # the momentum encoder should run without dropout; editing the config of an already-built
            # model does not change its nn.Dropout modules, so set their probabilities directly
            for module in self.moco_encoder[0].auto_model.modules():
                if isinstance(module, nn.Dropout):
                    module.p = 0.0
            self.q.extend(self.moco_encoder(batch1_features)['sentence_embedding'])
            
        if negative_samples is not None:
            # append the queued embeddings as extra negative columns; labels still point at the
            # diagonal of the first N columns, i.e. the in-batch positives
            cos_sim_with_neg=self.cal_cos_sim(batch1_embeddings,negative_samples)/self.temperature#(N,M)
            cos_sim=torch.cat([cos_sim,cos_sim_with_neg],dim=1)#(N,N+M)
        # momentum (EMA) update of the moco encoder: theta_moco <- gamma*theta_moco + (1-gamma)*theta_encoder
        for encoder_param,moco_encoder_param in zip(self.encoder.parameters(),self.moco_encoder.parameters()):
            moco_encoder_param.data=self.gamma*moco_encoder_param.data+(1.-self.gamma)*encoder_param.data
            
        return nn.CrossEntropyLoss()(cos_sim,labels)
    
    def encode(self, sentences,
               batch_size: int = 32,
               show_progress_bar: bool = None,
               output_value: str = 'sentence_embedding',
               convert_to_numpy: bool = True,
               convert_to_tensor: bool = False,
               device: str = None,
               normalize_embeddings: bool = False):
        '''
        The sentences passed in must form a single batch.
        '''
        return self.encoder.encode(sentences=sentences,
                                         batch_size=batch_size,
                                         show_progress_bar=show_progress_bar,
                                         output_value=output_value,
                                         convert_to_numpy=convert_to_numpy,
                                         convert_to_tensor=convert_to_tensor,
                                         device=device,
                                         normalize_embeddings=normalize_embeddings)
    
    def save(self,output_path):
        os.makedirs(output_path,exist_ok=True)
        with open(os.path.join(output_path, 'model_param_config.json'), 'w') as fOut:
            json.dump(self.get_config_dict(output_path), fOut)
        self.encoder.save(output_path)
        
    def get_config_dict(self,output_path):
        '''
        The config dict is required so that the model can be re-initialized when loading.
        '''
        return {'bert_model_path':output_path,'temperature': self.temperature, 'is_distilbert': self.is_distilbert,
                'q_size':self.q_size,'dup_rate':self.dup_rate,'gamma':self.gamma}
    @staticmethod
    def load(input_path):
        with open(os.path.join(input_path, 'model_param_config.json')) as fIn:
            config = json.load(fIn)
        return ESimCSE(**config)
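To make the objective in forward easier to follow, here is a minimal, self-contained sketch of the same InfoNCE computation on random tensors (toy sizes, reusing the torch/nn imports above; not the real model): the N in-batch positives sit on the diagonal of the first N columns, the M queued momentum-encoder embeddings only contribute extra negative columns, and the temperature rescales the logits before cross-entropy.

N, M, d, t = 4, 8, 768, 0.05                      # batch size, queue size, embedding dim, temperature
h1, h2 = torch.randn(N, d), torch.randn(N, d)     # embeddings of the two views of the same N sentences
queue_embs = torch.randn(M, d)                    # embeddings held by the momentum queue
candidates = torch.cat([h2, queue_embs], dim=0)   # (N+M, d)
logits = nn.functional.cosine_similarity(h1.unsqueeze(1), candidates.unsqueeze(0), dim=-1) / t  # (N, N+M)
toy_loss = nn.CrossEntropyLoss()(logits, torch.arange(N))  # label i -> column i, the positive pair
print(toy_loss.item())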

device='cpu'# switch to 'cuda' if a GPU is available
esimcse=ESimCSE(bert_model_path=model_path,
                is_distilbert=False,
                is_sbert_model=False,
                dup_rate=0.32,gamma=0.99,
                device=device)
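Before training, the word-repetition augmentation can be sanity-checked on a toy feature dict (a sketch: the middle token ids are arbitrary, 101/102 are BERT's [CLS]/[SEP], and the last three positions are padding):

toy_features={'input_ids':torch.LongTensor([[101,2769,4263,872,102,0,0,0]]),
              'attention_mask':torch.LongTensor([[1,1,1,1,1,0,0,0]]),
              'token_type_ids':torch.zeros(1,8,dtype=torch.long)}
print(esimcse.word_repetition(toy_features)['input_ids'].shape)# (1, 8+dup_len): a few tokens got repeated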

evaluator(esimcse)

epochs=5
output_path='path/to/save/the/model'# set this to the directory where the trained model should be saved
tensorboard_logdir=os.path.join(output_path,'log')

optimizer_type='AdamW'
scheduler='WarmupLinear'
warmup_proportion=0.1
optimizer_params={'lr': 2e-5}
weight_decay=0.01
num_train_steps = int(len(train_dataloader) * epochs)
warmup_steps = int(num_train_steps*warmup_proportion)
optimizer = get_optimizer(model=esimcse,optimizer_type=optimizer_type,weight_decay=weight_decay,optimizer_params=optimizer_params)
scheduler = get_scheduler(optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps)


trainer=Trainer(epochs=epochs,output_path=output_path,tensorboard_logdir=tensorboard_logdir,early_stop_patience=20)
trainer.train(train_dataloader=train_dataloader,
             model=esimcse,
             optimizer=optimizer,
             scheduler=scheduler,
             evaluator=evaluator,
             )
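After training, the saved model can be reloaded and used for encoding roughly as follows (a sketch based on the save/load/encode methods defined above; it assumes the Trainer wrote the best checkpoint to output_path via model.save):

esimcse_loaded=ESimCSE.load(output_path)
embeddings=esimcse_loaded.encode(['今天天气不错','今天天气真好'],batch_size=2)# numpy array, shape (2, hidden_size)
cos=np.dot(embeddings[0],embeddings[1])/(np.linalg.norm(embeddings[0])*np.linalg.norm(embeddings[1]))
print(cos)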

Dataset source: STS-B (the Chinese version from CNSD).
All of the experiment code can be found in my companion post on reproducing SimCSE, ESimCSE and related papers.
