【自然语言处理】【Pytorch】从头实现SimCSE

本文介绍了一个基于对比学习的句子向量表示模型SimCSE的实现过程。从定义参数到读取数据,再到构建Dataset和collate_fn,最终完成模型训练并保存。涉及PyTorch、Transformers等库的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

相关博客:
【自然语言处理】【对比学习】SimCSE:基于对比学习的句向量表示
【自然语言处理】BERT-Whitening
【自然语言处理】【Pytorch】从头实现SimCSE
【自然语言处理】【向量检索】面向开发域稠密检索的多视角文档表示学习
【自然语言处理】【向量表示】AugSBERT:改善用于成对句子评分任务的Bi-Encoders的数据增强方法
【自然语言处理】【向量表示】PairSupCon:用于句子表示的成对监督对比学习

github:

Concise_SimCSE

博客内容:基于pytorch和transformers,从头开始实现SimCSE
import torch
import torch.nn as nn

from abc import ABC
from tqdm.notebook import tqdm
from dataclasses import dataclass, field
from typing import List, Union, Optional, Dict
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, TrainingArguments, Trainer
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions

一、定义参数

@dataclass
class DataArguments:
    train_file: str = field(default="./data/simcse/wiki1m_for_simcse.txt",
                            metadata={"help": "The path of train file"})
    model_name_or_path: str = field(default="E:/pretrained/bert-base-uncased",
                                    metadata={"help": "The name or path of pre-trained language model"})
    max_seq_length: int = field(default=32,
                                metadata={"help": "The maximum total input sequence length after tokenization."})


training_args = TrainingArguments(
        output_dir="./checkpoints",
        num_train_epochs=1,
        per_device_train_batch_size=64,
        learning_rate=3e-5,
        load_best_model_at_end=True,
        overwrite_output_dir=True,
        do_train=True,
        do_eval=False,
        logging_steps=10)


data_args = DataArguments()

二、读取数据

# 初始化tokenizer
tokenizer = BertTokenizer.from_pretrained(data_args.model_name_or_path)
# 读取训练数据
with open(data_args.train_file, encoding="utf8") as file:
    texts = [line.strip() for line in tqdm(file.readlines())]
print(type(texts))
print(texts[0])
<class 'list'>
YMCA in South Australia

三、构建Dataset和collate_fn

3.1 构建Dataset

class PairDataset(Dataset):
    def __init__(self, examples: List[str]):
        total = len(examples)
        # 将所有样本复制一份用于对比学习
        sentences_pair = examples + examples
        sent_features = tokenizer(sentences_pair,
                                  max_length=data_args.max_seq_length,
                                  truncation=True,
                                  padding=False)
        features = {}
        # 将相同的样本放在同一个列表中
        for key in sent_features:
            features[key] = [[sent_features[key][i], sent_features[key][i + total]] for i in tqdm(range(total))]
        self.input_ids = features["input_ids"]
        self.attention_mask = features["attention_mask"]
        self.token_type_ids = features["token_type_ids"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, item):
        return {
            "input_ids": self.input_ids[item],
            "attention_mask": self.attention_mask[item],
            "token_type_ids": self.token_type_ids[item]
        }

train_dataset = PairDataset(texts)
print(train_dataset[0])
{'input_ids': [[101, 26866, 1999, 2148, 2660, 102], [101, 26866, 1999, 2148, 2660, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]}

3.2 构建collate_fn

@dataclass
class DataCollator:
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        special_keys = ['input_ids', 'attention_mask', 'token_type_ids']
        batch_size = len(features)
        if batch_size == 0:
            return
        # flat_features: [sen1, sen1, sen2, sen2, ...]
        flat_features = []
        for feature in features:
            for i in range(2):
                flat_features.append({k: feature[k][i] for k in feature.keys() if k in special_keys})
        # padding
        batch = self.tokenizer.pad(
            flat_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # batch_size, 2, seq_len
        batch = {k: batch[k].view(batch_size, 2, -1) for k in batch if k in special_keys}
        return batch

collate_fn = DataCollator(tokenizer)
dataloader = DataLoader(train_dataset, batch_size=4, collate_fn=collate_fn)
batch = next(iter(dataloader))
print(batch.keys())
print(batch["input_ids"].shape)
dict_keys(['input_ids', 'attention_mask', 'token_type_ids'])
torch.Size([4, 2, 32])

四、构建模型

# 全连接层,用于投影CLS的向量表示
class MLPLayer(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.dense = nn.Linear(input_size, output_size)
        self.activation = nn.Tanh()

    def forward(self, features, **kwargs):
        x = self.dense(features)
        x = self.activation(x)
        return x

# 相似度层,计算向量间相似度
class Similarity(nn.Module):
    def __init__(self, temp):
        super().__init__()
        self.temp = temp
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, x, y):
        return self.cos(x, y) / self.temp

    
# SimCSE的完整模型结构
class BertForCL(BertPreTrainedModel, ABC):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.mlp = MLPLayer(config.hidden_size, config.hidden_size)
        self.sim = Similarity(temp=0.05)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                sent_emb=False):
        if sent_emb:
            # 模型推断时使用的forward
            return self.sentemb_forward(input_ids=input_ids,
                                        attention_mask=attention_mask,
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
                                        head_mask=head_mask,
                                        inputs_embeds=inputs_embeds,
                                        labels=labels,
                                        output_attentions=output_attentions,
                                        output_hidden_states=output_hidden_states,
                                        return_dict=return_dict)
        else:
            # 模型训练时使用的forward
            return self.cl_forward(input_ids=input_ids,
                                   attention_mask=attention_mask,
                                   token_type_ids=token_type_ids,
                                   position_ids=position_ids,
                                   head_mask=head_mask,
                                   inputs_embeds=inputs_embeds,
                                   labels=labels,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states,
                                   return_dict=return_dict)

    def sentemb_forward(self,
                        input_ids=None,
                        attention_mask=None,
                        token_type_ids=None,
                        position_ids=None,
                        head_mask=None,
                        inputs_embeds=None,
                        labels=None,
                        output_attentions=None,
                        output_hidden_states=None,
                        return_dict=None):
        # 1.使用bert进行编码
        outputs = self.bert(input_ids, attention_mask=attention_mask, return_dict=True)
        # 2.取cls的表示
        cls_output = outputs.last_hidden_state[:, 0]
        # 3.使用MLP进行投影
        cls_output = self.mlp(cls_output)
        # 返回
        if not return_dict:
            return (outputs[0], cls_output) + outputs[2:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            pooler_output=cls_output,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
        )

    def cl_forward(self,
                   input_ids=None,
                   attention_mask=None,
                   token_type_ids=None,
                   position_ids=None,
                   head_mask=None,
                   inputs_embeds=None,
                   labels=None,
                   output_attentions=None,
                   output_hidden_states=None,
                   return_dict=None):
        # input_ids: batch_size, num_sent, len
        batch_size = input_ids.size(0)
        num_sent = input_ids.size(1)  # 2
        # 1. 重塑输入张量的形状,使其满足bert对输入的要求
        # input_ids: batch_size * num_sent, len
        input_ids = input_ids.view((-1, input_ids.size(-1)))
        attention_mask = attention_mask.view((-1, attention_mask.size(-1)))
        # 2. 使用bert进行编码
        outputs = self.bert(input_ids, attention_mask=attention_mask, return_dict=True)
        # 3. 取cls的向量表示
        cls_output = outputs.last_hidden_state[:, 0]
        # 4. 重塑形状
        cls_output = cls_output.view((batch_size, num_sent, cls_output.size(-1)))
        # 5. 全连接层投影
        # batch_size, num_sent, 768
        cls_output = self.mlp(cls_output)
        # 6. 将同一批样本的两次向量表示分开
        z1, z2 = cls_output[:, 0], cls_output[:, 1]
        # 7. 计算两两相似度,得到相似度矩阵cos_sim
        cos_sim = self.sim(z1.unsqueeze(1), z2.unsqueeze(0))
        # 8. 生成标签[0,1,...,batch_size-1],该标签用于提高相似度句子cos_sim对角线,并降低非对角线
        labels = torch.arange(cos_sim.size(0)).long().to(self.device)
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(cos_sim, labels)

        if not return_dict:
            output = (cos_sim,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=cos_sim,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

model = BertForCL.from_pretrained(data_args.model_name_or_path)
cl_out = model(**batch, return_dict=True)
print(cl_out.keys())
odict_keys(['loss', 'logits'])

五、模型训练

model.resize_token_embeddings(len(tokenizer))
trainer = Trainer(model=model,
                  train_dataset=train_dataset,
                  args=training_args,
                  tokenizer=tokenizer,
                  data_collator=collate_fn)
trainer.train()
trainer.save_model("models/test")
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BQW_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值