预训练语言模型复现（二）：Whole Word Mask

从标题中的“mask”可以联想到掩码语言模型（Masked Language Model，MLM），其代表是 BERT 系列模型。

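掩码方式可以分为 token mask 和 whole word mask。两者的区别在于：token mask 独立地对每个 token（中文即单个汉字）随机掩码；whole word mask 则以词为单位，只要一个词中的某个 token 被选中，该词的所有 token 都会被一起掩码。
那么 whole word mask 怎么实现？
整个实现过程可以借鉴 Hugging Face transformers 库的源码。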

Whole Word Mask

参考：transformers 库 `src/transformers/data/data_collator.py` 中 `DataCollatorForWholeWordMask` 的实现。

@dataclass
class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
    """
    Data collator used for language modeling that masks entire words.

    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for masked language modeling: tokens starting with
      "##" are grouped with the preceding token, and a selected word is
      masked as a whole
    - for Chinese, an optional per-example ``chinese_ref`` list gives the
      positions of characters that continue a word; those tokens are treated
      as "##" sub-words
    """

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Pad ``examples`` into a batch and apply whole-word masking.

        Args:
            examples: either raw ``input_ids`` sequences, or dicts with an
                ``input_ids`` entry and optionally a ``chinese_ref`` entry.

        Returns:
            dict with ``input_ids`` and ``labels`` as returned by
            ``self.mask_tokens``.
        """
        if isinstance(examples[0], (dict, BatchEncoding)):
            input_ids = [e["input_ids"] for e in examples]
        else:
            input_ids = examples
            examples = [{"input_ids": e} for e in examples]

        batch_input = _collate_batch(input_ids, self.tokenizer)

        mask_labels = []
        for example in examples:
            ref_tokens = [
                self.tokenizer._convert_id_to_token(token_id)
                for token_id in tolist(example["input_ids"])
            ]

            # Chinese characters are single tokens, so word boundaries must come
            # from ``chinese_ref``: e.g. [喜, 欢] -> [喜, ##欢]
            if "chinese_ref" in example:
                # A set gives O(1) membership per position (a list made this O(n*m)).
                ref_pos = set(tolist(example["chinese_ref"]))
                for i in range(len(example["input_ids"])):
                    if i in ref_pos:
                        ref_tokens[i] = "##" + ref_tokens[i]
            mask_labels.append(self._whole_word_mask(ref_tokens))
        # NOTE(review): mask labels are padded by the same helper as input_ids,
        # i.e. presumably with the pad token id (0 for BERT) -- confirm for
        # tokenizers whose pad id is not 0.
        batch_mask = _collate_batch(mask_labels, self.tokenizer)
        inputs, labels = self.mask_tokens(batch_input, batch_mask)
        return {"input_ids": inputs, "labels": labels}

    def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
        """
        Get 0/1 labels for masked tokens with whole word mask proxy.

        Tokens starting with "##" are joined to the preceding token to form a
        word; "[CLS]" and "[SEP]" are never candidates. Words are shuffled and
        taken whole until ``round(len(input_tokens) * mlm_probability)`` tokens
        (at least 1, at most ``max_predictions``) are covered. A word that
        would overshoot that budget is skipped.

        Returns:
            list with one entry per input token: 1 if selected, else 0.
        """
        cand_indexes = []
        for i, token in enumerate(input_tokens):
            if token in ("[CLS]", "[SEP]"):
                continue

            if cand_indexes and token.startswith("##"):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])

        random.shuffle(cand_indexes)
        num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
        covered_indexes = set()
        for index_set in cand_indexes:
            if len(covered_indexes) >= num_to_predict:
                break
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(covered_indexes) + len(index_set) > num_to_predict:
                continue
            if any(index in covered_indexes for index in index_set):
                continue
            covered_indexes.update(index_set)

        return [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]

中文 Whole Word Mask 的实现

参考原文

# coding=utf-8
'''
    Whole word mask for bert
'''
import pkuseg
from transformers import BertConfig, BertForMaskedLM, DataCollatorForWholeWordMask,\
    BertTokenizer, TrainingArguments, Trainer
from torch.utils.data import Dataset
from tqdm import tqdm
import torch

class My_wwm_pretrain_dataset(Dataset):
    """Line-per-example corpus for Chinese whole-word-mask (WWM) pretraining.

    Every non-blank line of the input file becomes one example of the form
    ``{'input_ids': [...], 'chinese_ref': [...]}``. ``chinese_ref`` holds the
    positions in ``input_ids`` of Chinese characters that continue (are not
    the first character of) a word produced by the word segmenter. The WWM
    data collator prefixes those tokens with "##" so the whole word is masked
    together.
    """

    # Deletes the whitespace the cleaning step strips: newline, carriage
    # return, tab, ASCII space and full-width space (U+3000).
    # Note: removing ASCII spaces also glues adjacent Latin words together.
    _CLEAN_TABLE = str.maketrans('', '', '\n\r\t \u3000')

    def __init__(self, path, tokenizer, dup_factor=5, max_length=512, segmenter=None):
        """
        Args:
            path: UTF-8 text file, one example per line.
            tokenizer: BERT-style tokenizer (uses ``tokenize`` and ``encode_plus``).
            dup_factor: number of copies of the corpus. Masks are drawn by the
                collator at batch time, so each copy is masked differently
                (dynamic masking).
            max_length: truncation length passed to ``encode_plus``.
            segmenter: object with ``cut(text) -> list of words`` (e.g. a
                ``pkuseg.pkuseg`` instance). Defaults to the module-level ``seg``.
        """
        self.segmenter = segmenter
        with open(path, 'r', encoding='utf-8') as f:
            total_data = f.readlines()

        # Segment and tokenize each line once; copies are made afterwards.
        unique_examples = []
        with tqdm(total_data) as loader:
            loader.set_description('loading data')
            for data in loader:
                data = self._clean_text(data)
                if not data:
                    # A blank line encodes to [CLS] [SEP] only: the collator has
                    # no candidate tokens to mask, so it would carry no labels.
                    continue
                tokens = tokenizer.tokenize(data)
                input_ids = tokenizer.encode_plus(data, truncation=True, max_length=max_length).input_ids
                # Index of the trailing [SEP]; refs at or beyond it (cut off by
                # truncation) must not turn [SEP] into "##[SEP]".
                # Assumes the BERT [CLS] ... [SEP] layout -- TODO confirm for
                # other tokenizers.
                last_position = len(input_ids) - 1
                chinese_ref = [pos for pos in self.get_new_segment(data, tokens) if pos < last_position]
                unique_examples.append({'input_ids': input_ids, 'chinese_ref': chinese_ref})

        # Same ordering as repeating the whole file dup_factor times; each copy
        # is its own dict (the id/ref lists are shared and only read).
        self.examples = [dict(example) for _ in range(dup_factor) for example in unique_examples]

    def get_new_segment(self, segment, tokens=None):
        """
        Use the word segmenter to find "##" positions for whole word mask.

        e.g. "我喜欢" cut into ["我", "喜欢"] -> [我, 喜, ##欢]; returns [3]
        (position of 欢, counting [CLS] as position 0).

        Args:
            segment: cleaned text of one example.
            tokens: the tokenizer's tokens for ``segment`` (no special tokens).
                When given, positions are aligned to these tokens: only Chinese
                characters are marked, and Latin/digit text that tokenizes into
                a different number of pieces cannot shift later positions.
                When omitted, the legacy character count is used, which is
                only correct if every character is exactly one token.

        Returns:
            list of int positions into ``input_ids`` (offset by 1 for [CLS]).
        """
        segmenter = getattr(self, 'segmenter', None)
        if segmenter is None:
            segmenter = seg  # module-level segmenter created by the training script
        words = segmenter.cut(segment)  # pkuseg (medical model) word segmentation
        if tokens is None:
            return self._char_aligned_refs(words)
        return self._token_aligned_refs(words, tokens)

    @staticmethod
    def _char_aligned_refs(words):
        """Legacy indexing: one position per character, [CLS] at position 0."""
        refs = []
        position = 1
        for word in words:
            # Every character of the word except the first is a continuation.
            refs.extend(range(position + 1, position + len(word)))
            position += len(word)
        return refs

    @classmethod
    def _token_aligned_refs(cls, words, tokens):
        """Map each word's non-initial Chinese characters to their token positions."""
        refs = []
        cursor = 0  # first token index not yet consumed by a matched character
        for word in words:
            for offset, char in enumerate(word):
                if not cls._is_chinese_char(char):
                    # Non-Chinese text may become any number of word pieces; it is
                    # not marked, and the forward search below skips its tokens.
                    continue
                try:
                    found = tokens.index(char, cursor)
                except ValueError:
                    # The character did not survive tokenization (e.g. [UNK]).
                    # Every occurrence of it fails the same way, so skipping it
                    # cannot make a later character match the wrong token.
                    continue
                if offset > 0:
                    refs.append(found + 1)  # +1 for the leading [CLS]
                cursor = found + 1
        return refs

    @staticmethod
    def _is_chinese_char(char):
        """True if ``char`` is in the CJK ranges BERT splits into single-character tokens."""
        cp = ord(char)
        return (
            0x4E00 <= cp <= 0x9FFF
            or 0x3400 <= cp <= 0x4DBF
            or 0x20000 <= cp <= 0x2A6DF
            or 0x2A700 <= cp <= 0x2B73F
            or 0x2B740 <= cp <= 0x2B81F
            or 0x2B820 <= cp <= 0x2CEAF
            or 0xF900 <= cp <= 0xFAFF
            or 0x2F800 <= cp <= 0x2FA1F
        )

    @classmethod
    def _clean_text(cls, text):
        """Strip line breaks, tabs, ASCII spaces and full-width spaces."""
        return text.translate(cls._CLEAN_TABLE)

    def __getitem__(self, index):
        return self.examples[index]

    def __len__(self):
        return len(self.examples)


if __name__ == '__main__':
    import math

    # configuration
    epoch = 100
    batch_size = 1
    pretrained_model = 'mc-bert-base'
    train_file = 'data/train.txt'
    save_epoch = 10  # save a checkpoint every 10 epochs
    bert_file = '../../pretrained_models/' + pretrained_model
    tokenizer_model_path = '../../pretrained_models/pkuseg_medical'
    output_dir = 'pretrain_outputs/wwm/'

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Medical-domain word segmenter. It is bound to the module-level name `seg`
    # because My_wwm_pretrain_dataset.get_new_segment reads that global.
    seg = pkuseg.pkuseg(model_name=tokenizer_model_path)
    tokenizer = BertTokenizer.from_pretrained(bert_file)
    train_dataset = My_wwm_pretrain_dataset(train_file, tokenizer)
    model = BertForMaskedLM.from_pretrained(bert_file).to(device)
    print('No of parameters: ', model.num_parameters())
    data_collator = DataCollatorForWholeWordMask(
        tokenizer=tokenizer, mlm=True, mlm_probability=0.15
    )
    print('No. of lines: ', len(train_dataset))
    # Optimizer steps per epoch depend on the batch size (the old formula
    # len(dataset) * save_epoch was only right for batch_size == 1).
    # Assumes a single device and no gradient accumulation -- adjust otherwise.
    steps_per_epoch = math.ceil(len(train_dataset) / batch_size)
    save_step = steps_per_epoch * save_epoch
    tot_step = steps_per_epoch * epoch
    print(f'\n\t***** Running training *****\n'
          f'\tNum examples = {len(train_dataset)}\n'
          f'\tNum Epochs = {epoch}\n'
          f'\tBatch size = {batch_size}\n'
          f'\tTotal optimization steps = {tot_step}\n')

    # official training
    training_args = TrainingArguments(
        output_dir='./outputs/',
        overwrite_output_dir=True,
        num_train_epochs=epoch,
        per_device_train_batch_size=batch_size,
        save_steps=save_step,
        # Keep 'chinese_ref': the Trainer can otherwise drop dataset keys that
        # model.forward does not accept before the collator sees them, which
        # silently degrades whole word mask to per-character masking.
        remove_unused_columns=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )

    trainer.train()
    trainer.save_model(output_dir)
    # Save the vocabulary too so the output directory can be reloaded directly.
    tokenizer.save_pretrained(output_dir)



评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

YingJingh

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值