Masked Language Model: Detailed Implementation and Approach

This article shows how to implement masked language modeling (MLM) training for BERT using PyTorch and HuggingFace's Transformers library, covering data preprocessing, configuration, and the training loop.
import os
import json
import copy
from tqdm.notebook import tqdm

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import BertForMaskedLM, BertTokenizerFast


class Config:
    def __init__(self):
        pass

    def mlm_config(
            self,
            mlm_probability=0.15,
            special_tokens_mask=None,
            prob_replace_mask=0.8,
            prob_replace_rand=0.1,
            prob_keep_ori=0.1,
    ):
        """
        :param mlm_probability: 被mask的token总数
        :param special_token_mask: 特殊token
        :param prob_replace_mask: 被替换成[MASK]的token比率
        :param prob_replace_rand: 被随机替换成其他token比率
        :param prob_keep_ori: 保留原token的比率
        """
        # use a tolerance rather than exact float equality
        assert abs(sum([prob_replace_mask, prob_replace_rand, prob_keep_ori]) - 1.0) < 1e-6, \
            "Sum of the probs must equal 1."
        self.mlm_probability = mlm_probability
        self.special_tokens_mask = special_tokens_mask
        self.prob_replace_mask = prob_replace_mask
        self.prob_replace_rand = prob_replace_rand
        self.prob_keep_ori = prob_keep_ori

    def training_config(
            self,
            batch_size,
            epochs,
            learning_rate,
            weight_decay,
            device,
    ):
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.device = device

    def io_config(
            self,
            from_path,
            save_path,
    ):
        self.from_path = from_path
        self.save_path = save_path


class TrainDataset(Dataset):
    """
    注意:由于没有使用data_collator,batch放在dataset里边做,
    因而在dataloader出来的结果会多套一层batch维度,传入模型时注意squeeze掉
    """

    def __init__(self, input_texts, tokenizer, config):
        self.input_texts = input_texts
        self.tokenizer = tokenizer
        self.config = config
        self.ori_inputs = copy.deepcopy(input_texts)

    def __len__(self):
        return len(self.input_texts) // self.config.batch_size

    def __getitem__(self, idx):
        # `idx` is ignored: each call consumes the next `batch_size` texts from the front of the pool.
        batch_text = self.input_texts[: self.config.batch_size]
        features = self.tokenizer(batch_text, max_length=512, truncation=True, padding=True, return_tensors='pt')
        # inputs: token ids after [MASK]/random-token corruption;
        # labels: the original ids, with -100 at every position not selected for masking
        inputs, labels = self.mask_tokens(features['input_ids'])
        batch = {"inputs": inputs, "labels": labels}
        self.input_texts = self.input_texts[self.config.batch_size:]
        if not len(self):
            # all batches consumed: restore the full text pool for the next epoch
            self.input_texts = self.ori_inputs

        return batch

    def mask_tokens(self, inputs):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.config.mlm_probability)  # shape [batch, seq_len], every entry = mlm_probability (e.g. 0.15)
        if self.config.special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = self.config.special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()  # each position is drawn independently; 1 -> True marks a token selected for MLM
        labels[~masked_indices] = -100  # we only compute loss on masked tokens; unselected (False) positions become -100

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(
            torch.full(labels.shape, self.config.prob_replace_mask)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)  # replace with the id of [MASK] (103 for bert-base-chinese)

        # 10% of the time, we replace masked input tokens with a random word.
        # The Bernoulli below only runs on positions that were selected but not already turned
        # into [MASK] (a 1 - 0.8 = 0.2 share), so the conditional probability is 0.1 / 0.2 = 0.5.
        current_prob = self.config.prob_replace_rand / (1 - self.config.prob_replace_mask)
        indices_random = torch.bernoulli(
            torch.full(labels.shape, current_prob)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
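

# Optional sanity check (illustrative helper, not part of the training pipeline):
# decode one batch produced by TrainDataset to inspect the 80%/10%/10% corruption.
# It assumes `dataset` and `tokenizer` are built the same way as in __main__ below;
# note that indexing the dataset consumes one batch of texts from its pool.
def preview_masking(dataset, tokenizer, max_positions=20):
    batch = dataset[0]
    inputs, labels = batch["inputs"], batch["labels"]
    # positions where labels != -100 are the ones selected for MLM
    selected = (labels[0] != -100).nonzero(as_tuple=True)[0][:max_positions]
    for pos in selected:
        original = tokenizer.convert_ids_to_tokens(int(labels[0, pos]))
        corrupted = tokenizer.convert_ids_to_tokens(int(inputs[0, pos]))
        print(f"{original} -> {corrupted}")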


def train(model, train_dataloader, config):
    """
    训练
    :param model: nn.Module
    :param train_dataloader: DataLoader
    :param config: Config
    ---------------
    ver: 2021-11-08
    by: changhongyu
    """
    assert config.device.startswith('cuda') or config.device == 'cpu', "Invalid device."
    device = torch.device(config.device)

    model.to(device)

    if not len(train_dataloader):
        raise EOFError("Empty train_dataloader.")

    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    # Per-group weight_decay values below override the weight_decay passed to AdamW:
    # 0.01 for ordinary weights, 0.0 for biases and LayerNorm parameters.
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}]

    optimizer = AdamW(params=optimizer_grouped_parameters, lr=config.learning_rate, weight_decay=config.weight_decay)

    for cur_epc in tqdm(range(int(config.epochs)), desc="Epoch"):
        training_loss = 0
        print("Epoch: {}".format(cur_epc + 1))
        model.train()
        for step, batch in enumerate(tqdm(train_dataloader, desc='Step')):
            input_ids = batch['inputs'].squeeze(0).to(device)  # [batch, seq_len]: token ids after [MASK]/random corruption
            labels = batch['labels'].squeeze(0).to(device)     # [batch, seq_len]: original ids at masked positions, -100 elsewhere
            # single forward pass; outputs.logits has shape [batch, seq_len, vocab_size]
            outputs = model(input_ids=input_ids, labels=labels)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss += loss.item()
        print("Training loss: ", training_loss)



if __name__ == '__main__':
    config = Config()
    config.mlm_config()
    config.training_config(batch_size=4, epochs=10, learning_rate=1e-5, weight_decay=0, device='cuda:0')
    config.io_config(from_path='/root/autodl-tmp/bert-base-chinese',
                     save_path='mlm')
    bert_tokenizer = BertTokenizerFast.from_pretrained(config.from_path)
    bert_mlm_model = BertForMaskedLM.from_pretrained(config.from_path)
    training_texts = [
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
    ]
    train_dataset = TrainDataset(training_texts, bert_tokenizer, config)
    # DataLoader keeps its default batch_size=1 because batching is already done inside TrainDataset
    train_dataloader = DataLoader(train_dataset)
    train(model=bert_mlm_model, train_dataloader=train_dataloader, config=config)
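
    # Persist the fine-tuned weights and run a quick fill-mask spot check.
    # This closing step is a sketch, not part of the training code above; it relies on the
    # save_model helper defined earlier and the standard fill-mask pipeline from Transformers.
    save_model(model=bert_mlm_model, tokenizer=bert_tokenizer, config=config)

    from transformers import pipeline
    fill_mask = pipeline("fill-mask", model=config.save_path, tokenizer=config.save_path)
    print(fill_mask("这是一条[MASK]本"))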
