From the title, "mask" points to the masked language model (MLM) objective, which is the pre-training recipe behind BERT and its family of models.
Masking comes in two flavors, token mask and whole word mask. How is each one implemented?
And what is the difference between the two?
The whole implementation can be borrowed from the Hugging Face transformers source code; a small illustration of the difference comes first.
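Here is a minimal sketch of the difference (the sentence and its WordPiece split are illustrative, not taken from the library):

# Token mask vs. whole word mask on a made-up example.
# WordPiece splits "playing" into ["play", "##ing"].
tokens       = ["[CLS]", "i", "like", "play", "##ing", "football", "[SEP]"]

# Token-level mask: every token is an independent candidate,
# so "##ing" can be masked while "play" stays visible.
token_masked = ["[CLS]", "i", "like", "play", "[MASK]", "football", "[SEP]"]

# Whole word mask: "play" and "##ing" belong to one word,
# so they are masked (or left alone) together.
wwm_masked   = ["[CLS]", "i", "like", "[MASK]", "[MASK]", "football", "[SEP]"]

For Chinese the tokenizer never produces "##" by itself, because BERT splits Chinese text character by character; the word boundaries have to come from an external segmenter (pkuseg in the code further down), which is exactly what the chinese_ref field is for.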
Whole word mask
# Adapted from transformers/data/data_collator.py; it relies on the module's
# internal helpers _collate_batch and tolist.
@dataclass
class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
    """
    Data collator used for language modeling.

    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for masked language modeling
    """

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        if isinstance(examples[0], (dict, BatchEncoding)):
            input_ids = [e["input_ids"] for e in examples]
        else:
            input_ids = examples
            examples = [{"input_ids": e} for e in examples]

        batch_input = _collate_batch(input_ids, self.tokenizer)

        mask_labels = []
        for e in examples:
            ref_tokens = []
            for id in tolist(e["input_ids"]):
                token = self.tokenizer._convert_id_to_token(id)
                ref_tokens.append(token)

            # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
            if "chinese_ref" in e:
                ref_pos = tolist(e["chinese_ref"])
                len_seq = len(e["input_ids"])
                for i in range(len_seq):
                    if i in ref_pos:
                        ref_tokens[i] = "##" + ref_tokens[i]
            mask_labels.append(self._whole_word_mask(ref_tokens))

        batch_mask = _collate_batch(mask_labels, self.tokenizer)
        inputs, labels = self.mask_tokens(batch_input, batch_mask)
        return {"input_ids": inputs, "labels": labels}

    def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
        """
        Get 0/1 labels for masked tokens with whole word mask proxy
        """
        cand_indexes = []
        for (i, token) in enumerate(input_tokens):
            if token == "[CLS]" or token == "[SEP]":
                continue

            # A "##"-prefixed token continues the previous word, so it joins
            # that word's candidate group instead of starting a new one.
            if len(cand_indexes) >= 1 and token.startswith("##"):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])

        random.shuffle(cand_indexes)
        num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
        masked_lms = []
        covered_indexes = set()
        for index_set in cand_indexes:
            if len(masked_lms) >= num_to_predict:
                break
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(masked_lms) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                covered_indexes.add(index)
                masked_lms.append(index)

        assert len(covered_indexes) == len(masked_lms)
        mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
        return mask_labels
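To make the collator's data flow concrete, here is a hypothetical trace (the sentence, positions and labels are made up for illustration and are not output of the library):

# Suppose the sentence is "我喜欢猫", segmented by pkuseg into 我 / 喜欢 / 猫.
# BERT tokenizes each character separately; chinese_ref = [3] marks 欢
# (position 3, with [CLS] at position 0) as the continuation of a word,
# so the collator rewrites the reference tokens as:
ref_tokens = ["[CLS]", "我", "喜", "##欢", "猫", "[SEP]"]

# _whole_word_mask then groups candidate indexes as [[1], [2, 3], [4]];
# if the word 喜欢 is selected, both positions are covered and the labels become:
mask_labels = [0, 0, 1, 1, 0, 0]
# mask_tokens (inherited from DataCollatorForLanguageModeling) finally applies
# the usual 80/10/10 mask/random/keep replacement to exactly those positions.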
Whole word mask for Chinese
# coding=utf-8
'''
Whole word mask pre-training for BERT (Chinese)
'''
import pkuseg
from transformers import BertConfig, BertForMaskedLM, DataCollatorForWholeWordMask, \
    BertTokenizer, TrainingArguments, Trainer
from torch.utils.data import Dataset
from tqdm import tqdm
import torch


class My_wwm_pretrain_dataset(Dataset):
    def __init__(self, path, tokenizer, dup_factor=5, max_length=512):
        # dup_factor: duplicate each line so dynamic masking sees it several times
        self.examples = []
        with open(path, 'r', encoding='utf-8') as f:
            total_data = f.readlines()
            with tqdm(total_data * dup_factor) as loader:
                for data in loader:
                    # clean data
                    data = data.replace('\n', '').replace('\r', '').replace('\t', '').replace(' ', '').replace(' ', '')
                    chinese_ref = self.get_new_segment(data)
                    input_ids = tokenizer.encode_plus(data, truncation=True, max_length=max_length).input_ids
                    dict_data = {'input_ids': input_ids, 'chinese_ref': chinese_ref}
                    self.examples.append(dict_data)
                    loader.set_description('loading data')

    def get_new_segment(self, segment):
        """
        Use a word segmenter to build the whole word mask reference,
        i.e. the positions of non-initial characters inside each word,
        e.g. [喜,欢] -> [喜,##欢]
        """
        # `seg` is the module-level pkuseg segmenter (medical-domain model) created in __main__
        seq_cws = seg.cut("".join(segment))
        chinese_ref = []
        index = 1  # position 0 is [CLS], so character positions start at 1
        for seq in seq_cws:
            for i, word in enumerate(seq):
                if i > 0:
                    chinese_ref.append(index)
                index += 1
        return chinese_ref

    def __getitem__(self, index):
        return self.examples[index]

    def __len__(self):
        return len(self.examples)
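A quick way to see what one training example contains (the paths and model names here are placeholders, not from the original script):

# `seg` must exist at module level, because get_new_segment reads it as a global.
seg = pkuseg.pkuseg()  # default general-domain model; the script below uses a medical one
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
dataset = My_wwm_pretrain_dataset('data/train.txt', tokenizer, dup_factor=1)

sample = dataset[0]
print(sample['input_ids'])    # [CLS] + one id per character + [SEP]
print(sample['chinese_ref'])  # positions of the non-initial characters of each segmented word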
if __name__ == '__main__':
    # configuration
    epoch = 100
    batch_size = 1
    pretrain_model = 'mc-bert-base'
    train_file = 'data/train.txt'
    save_epoch = 10  # save a checkpoint every 10 epochs
    bert_file = '../../pretrained_models/' + pretrain_model
    tokenizer_model_path = '../../pretrained_models/pkuseg_medical'  # pkuseg segmentation model (medical domain)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    seg = pkuseg.pkuseg(model_name=tokenizer_model_path)
    config = BertConfig.from_pretrained(bert_file)
    tokenizer = BertTokenizer.from_pretrained(bert_file)
    train_dataset = My_wwm_pretrain_dataset(train_file, tokenizer)
    model = BertForMaskedLM.from_pretrained(bert_file).to(device)
    print('No. of parameters: ', model.num_parameters())

    data_collator = DataCollatorForWholeWordMask(
        tokenizer=tokenizer, mlm=True, mlm_probability=0.15
    )
    print('No. of lines: ', len(train_dataset))

    save_step = int(len(train_dataset) / batch_size * save_epoch)  # steps per epoch * save_epoch
    tot_step = int(len(train_dataset) / batch_size * epoch)
    print(f'\n\t***** Running training *****\n'
          f'\tNum examples = {len(train_dataset)}\n'
          f'\tNum Epochs = {epoch}\n'
          f'\tBatch size = {batch_size}\n'
          f'\tTotal optimization steps = {tot_step}\n')

    # train with the Hugging Face Trainer
    training_args = TrainingArguments(
        output_dir='./outputs/',
        overwrite_output_dir=True,
        num_train_epochs=epoch,
        per_device_train_batch_size=batch_size,
        save_steps=save_step,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )
    trainer.train()
    trainer.save_model('pretrain_outputs/wwm/')