How to pretrain BERT on your own dataset

The script below continues masked-language-model (MLM) pretraining of a NeZha model on your own corpus, using an n-gram masking data collator and the Hugging Face Trainer.

# coding:utf-8

import os
import pickle
import random
import logging
import warnings
from argparse import ArgumentParser

import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import List, Tuple
from collections import defaultdict

import torch
from torch.utils.data import Dataset

from transformers import (
    BertTokenizer,
    TrainingArguments,
    Trainer
)
from src.util.modeling.modeling_nezha.modeling import NeZhaConfig, NeZhaForMaskedLM

warnings.filterwarnings('ignore')

logging.basicConfig()
logger = logging.getLogger('')
logger.setLevel(logging.INFO)


def save_pickle(dic, save_path):
    with open(save_path, 'wb') as f:
        pickle.dump(dic, f)


def load_pickle(load_path):
    with open(load_path, 'rb') as f:
        message_dict = pickle.load(f)
    return message_dict


def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def read_data(args, tokenizer: BertTokenizer) -> dict:
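    # The pretraining corpus is read as a headerless file with one sentence per line;
    # each line is tokenized into a single example and the features are cached as a pickle.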
    pretrain_df = pd.read_csv(args.pretrain_data_path, header=None, sep='\t')

    inputs = defaultdict(list)
    for i, row in tqdm(pretrain_df.iterrows(), desc='', total=len(pretrain_df)):
        sentence = row[0].strip()
        inputs_dict = tokenizer.encode_plus(sentence, add_special_tokens=True,
                                            return_token_type_ids=True, return_attention_mask=True)
        inputs['input_ids'].append(inputs_dict['input_ids'])
        inputs['token_type_ids'].append(inputs_dict['token_type_ids'])
        inputs['attention_mask'].append(inputs_dict['attention_mask'])

    os.makedirs(os.path.dirname(args.data_cache), exist_ok=True)
    save_pickle(inputs, args.data_cache)

    return inputs


class PretrainDataset(Dataset):
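    # Thin wrapper over the cached features: each item is the raw, unpadded
    # (input_ids, token_type_ids, attention_mask) triple; padding and MLM masking
    # are applied later by the data collator.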
    def __init__(self, data_dict: dict):
        super().__init__()
        self.data_dict = data_dict

    def __getitem__(self, index: int) -> tuple:
        data = (self.data_dict['input_ids'][index],
                self.data_dict['token_type_ids'][index],
                self.data_dict['attention_mask'][index])

        return data

    def __len__(self) -> int:
        return len(self.data_dict['input_ids'])


class PretrainDataCollator:
    def __init__(self, max_seq_len: int, tokenizer: BertTokenizer, mlm_probability=0.15):
        self.max_seq_len = max_seq_len
        self.tokenizer = tokenizer
        self.mlm_probability = mlm_probability
        self.special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id}

    def pad_and_truncate(self, input_ids_list, token_type_ids_list,
                         attention_mask_list, max_seq_len):
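        # Right-pad shorter sequences with zeros and truncate longer ones to
        # max_seq_len, keeping a trailing [SEP] on the truncated input_ids.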
        input_ids = torch.zeros((len(input_ids_list), max_seq_len), dtype=torch.long)
        token_type_ids = torch.zeros_like(input_ids)
        attention_mask = torch.zeros_like(input_ids)
        for i in range(len(input_ids_list)):
            seq_len = len(input_ids_list[i])
            if seq_len <= max_seq_len:
                input_ids[i, :seq_len] = torch.tensor(input_ids_list[i], dtype=torch.long)
                token_type_ids[i, :seq_len] = torch.tensor(token_type_ids_list[i], dtype=torch.long)
                attention_mask[i, :seq_len] = torch.tensor(attention_mask_list[i], dtype=torch.long)
            else:
                input_ids[i] = torch.tensor(input_ids_list[i][:max_seq_len - 1] + [self.tokenizer.sep_token_id],
                                            dtype=torch.long)
                token_type_ids[i] = torch.tensor(token_type_ids_list[i][:max_seq_len], dtype=torch.long)
                attention_mask[i] = torch.tensor(attention_mask_list[i][:max_seq_len], dtype=torch.long)
        return input_ids, token_type_ids, attention_mask

    def _ngram_mask(self, input_ids, max_seq_len):
        cand_indexes = []
        for (i, id_) in enumerate(input_ids):
            if id_ in self.special_token_ids:
                continue
            cand_indexes.append([i])
        num_to_predict = max(1, int(round(len(input_ids) * self.mlm_probability)))

        max_ngram = 3
        ngrams = np.arange(1, max_ngram + 1, dtype=np.int64)
        pvals = 1. / np.arange(1, max_ngram + 1)
        pvals /= pvals.sum(keepdims=True)
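        # pvals is proportional to 1/n, so shorter n-grams are chosen
        # more often than longer ones when sampling mask spans.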

        ngram_indexes = []
        for idx in range(len(cand_indexes)):
            ngram_index = []
            for n in ngrams:
                ngram_index.append(cand_indexes[idx:idx + n])
            ngram_indexes.append(ngram_index)
        np.random.shuffle(ngram_indexes)
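        # Greedily take shuffled n-gram spans (skipping any that overlap already
        # covered positions) until roughly mlm_probability of the tokens are selected.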

        covered_indexes = set()

        for cand_index_set in ngram_indexes:
            if len(covered_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            n = np.random.choice(ngrams[:len(cand_index_set)],
                                 p=pvals[:len(cand_index_set)] / pvals[:len(cand_index_set)].sum(keepdims=True))
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1
            while len(covered_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1
            if len(covered_indexes) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                covered_indexes.add(index)

        mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_ids))]
        mask_labels += [0] * (max_seq_len - len(mask_labels))

        return torch.tensor(mask_labels[:max_seq_len])

    def ngram_mask(self, input_ids_list: List[list], max_seq_len: int):
        mask_labels = []
        for i, input_ids in enumerate(input_ids_list):
            mask_label = self._ngram_mask(input_ids, max_seq_len)
            mask_labels.append(mask_label)
        return torch.stack(mask_labels, dim=0)

    def mask_tokens(self, inputs: torch.Tensor, mask_labels: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor]:

        labels = inputs.clone()
        probability_matrix = mask_labels
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        masked_indices = probability_matrix.bool()
        labels[~masked_indices] = -100
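        # Corrupt the selected positions as in standard BERT MLM: 80% -> [MASK],
        # 10% -> a random token, and the remaining 10% are left unchanged.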
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        indices_random = torch.bernoulli(
            torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        return inputs, labels

    def __call__(self, examples: list) -> dict:
        input_ids_list, token_type_ids_list, attention_mask_list = list(zip(*examples))
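        # Dynamic padding: pad each batch only to its longest sequence,
        # capped at the configured max_seq_len.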
        cur_max_seq_len = max(len(input_id) for input_id in input_ids_list)
        max_seq_len = min(cur_max_seq_len, self.max_seq_len)

        input_ids, token_type_ids, attention_mask = self.pad_and_truncate(input_ids_list,
                                                                          token_type_ids_list,
                                                                          attention_mask_list,
                                                                          max_seq_len)
        batch_mask = self.ngram_mask(input_ids_list, max_seq_len)
        input_ids, mlm_labels = self.mask_tokens(input_ids, batch_mask)
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
            'labels': mlm_labels
        }

        return data_dict


def main(model_type):
    parser = ArgumentParser()

    parser.add_argument('--pretrain_data_path', type=str, default='../../../user_data/process_data/pretrain.txt')
    parser.add_argument('--pretrain_model_path', type=str,
                        default=f'../../../user_data/{model_type}/pretrain_model/nezha-cn-base')
    parser.add_argument('--data_cache', type=str,
                        default=f'../../../user_data/{model_type}/process_data/pkl/pretrain.pkl')
    parser.add_argument('--vocab_path', type=str, default=f'../../../user_data/{model_type}/tokenizer/vocab.txt')
    parser.add_argument('--save_path', type=str, default='./model')
    parser.add_argument('--record_save_path', type=str, default='./record')
    parser.add_argument('--mlm_probability', type=float, default=0.15)
    parser.add_argument('--num_train_epochs', type=int, default=100)
    parser.add_argument('--seq_length', type=int, default=128)
    parser.add_argument('--batch_size', type=int, default=192)
    parser.add_argument('--learning_rate', type=float, default=8e-5)
    parser.add_argument('--save_steps', type=int, default=5000)
    parser.add_argument('--ckpt_save_limit', type=int, default=5)
    parser.add_argument('--logging_steps', type=int, default=500)
    parser.add_argument('--seed', type=int, default=2021)
    parser.add_argument('--fp16', type=lambda x: str(x).lower() == 'true', default=True)
    parser.add_argument('--fp16_backend', type=str, default='amp')

    warnings.filterwarnings('ignore')
    args = parser.parse_args()

    seed_everything(args.seed)

    os.makedirs(args.save_path, exist_ok=True)
    os.makedirs(args.record_save_path, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.vocab_path)
    model_config = NeZhaConfig.from_pretrained(args.pretrain_model_path)

    if not os.path.exists(args.data_cache):
        data = read_data(args, tokenizer)
    else:
        data = load_pickle(args.data_cache)

    data_collator = PretrainDataCollator(max_seq_len=args.seq_length,
                                         tokenizer=tokenizer,
                                         mlm_probability=args.mlm_probability)
    model = NeZhaForMaskedLM.from_pretrained(pretrained_model_name_or_path=args.pretrain_model_path,
                                             config=model_config)
    model.resize_token_embeddings(len(tokenizer))  # size embeddings to the custom vocabulary (includes any added tokens)
    dataset = PretrainDataset(data)

    training_args = TrainingArguments(
        seed=args.seed,
        fp16=args.fp16,
        fp16_backend=args.fp16_backend,
        save_steps=args.save_steps,
        prediction_loss_only=True,
        logging_steps=args.logging_steps,
        output_dir=args.record_save_path,
        learning_rate=args.learning_rate,
        save_total_limit=args.ckpt_save_limit,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator
    )

    trainer.train()
    trainer.save_model(args.save_path)
    tokenizer.save_vocabulary(args.save_path)


if __name__ == '__main__':
    main('nezha')
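After training, trainer.save_model and tokenizer.save_vocabulary write the weights, config, and vocab.txt to ./model. As a minimal sketch (assuming the default --save_path and the same NeZha modeling code on the import path), the pretrained checkpoint can be loaded back like this:

from transformers import BertTokenizer
from src.util.modeling.modeling_nezha.modeling import NeZhaConfig, NeZhaForMaskedLM

# load the continued-pretrained NeZha checkpoint produced by this script
tokenizer = BertTokenizer.from_pretrained('./model')
config = NeZhaConfig.from_pretrained('./model')
model = NeZhaForMaskedLM.from_pretrained('./model', config=config)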
