BERT finetune

BERT finetune

本文代码部分参考github项目:
https://github.com/BSlience/search-engine-zerotohero/tree/main/public/bert_wwm_pretrain

项目概述

本文的主要内容是基于huggingface transformer的chinese-bert-wwm模型,在自己的语料集上进行finetune的整体步骤和代码实现。

关于chinese-bert-wwm:

https://huggingface.co/hfl/chinese-bert-wwm

https://github.com/ymcui/Chinese-BERT-wwm

主要步骤包括:预处理和训练两个部分

预处理(pre-processing)

  1. 下载chinese-bert-wwm模型的预训练词表(vocab.txt)、config.json和pytorch_model.bin;
  2. 读取自己的原始数据集(比如大量的文章、文本),做句子分割,然后保存成语料集;
  3. 根据自己的语料集进行分词(BERT是分割成单个字),并将自己语料集中相比原始的词表多的字(或者词)添加到原始词表中(就是一个扩充操作),然后就生成了自己的词表;

步骤1的下载地址:https://huggingface.co/hfl/chinese-bert-wwm/tree/main

如果生成的语料集比较大,为了后续加载方便可以存储至内存型的数据库中(比如Redis)

训练(train)

  1. 加载BertTokenizer,读取语料集,生成数据;
  2. 针对1中得到的数据,进行填充、截断和mask等操作,借助data collator类;
  3. 加载chinese-bert-wwm模型的预训练权重文件,基于当前数据开始训练(微调);
  4. 保存模型,测试效果。

关于data collator的实现可以查看我的上一篇文章whole word mask;

步骤3加载的预训练权重文件就是上述预处理步骤1下载的,不过要注意把config.json和pytorch_model.bin放在同一个目录下,然后加载这个目录即可。

预处理

这里原始数据(大量的知乎文章)存储在mongodb中,我们读取出来,然后执行预处理步骤2、3

def clean_html_tag(data: str):
    """
    清除给定文本中包含的HTML标签
    :param data:
    :return:
    """
    if isinstance(data, float):
        print(data)
        return ""
    result = bs(data).get_text()
    return result


def read_data_from_mongodb(mongodb_url, db, collection):
    """
    从mongodb中读取数据
    :param mongodb_url: 连接mongodb的地址
    :param db: 数据库名称
    :param collection: 集合名称
    :return:
    """
    client = MongoClient(mongodb_url)
    collection = client[db][collection]
    data = defaultdict(list)  # 可以将字典的value设置list类型
    for each in tqdm(collection.find(batch_size=10)):   # tqdm用于显示进度
        try:
            if each["title"]:
                title = clean_html_tag(each["title"])
                data['title'].append(title)
            if each["excerpt"]:
                summary = clean_html_tag(each['excerpt'])
                data['summary'].append(summary)
            if each['content']:
                clean_content = clean_html_tag(each['content'])
                data['content'].append(clean_content)
        except Exception as _e:
            print(_e)
            print(each)
    return data


def get_split_sentences(data):
    split_sentence = SplitSentence()
    data_new = defaultdict(list)
    for key, value in data.items():
        if key == 'title':
            data_new[key].extend(data[key])
        else:
            sentences_list = []
            for sentences in data[key]:
                for each in split_sentence.split_sentence(sentences):
                    sentences_list.append(each)
                data_new[key].extend(sentences_list)
    return data_new


def pre_processing(redis_host, redis_port, redis_pwd, mongodb_host, db, collection):
    """
    预处理操作,包括从mongodb读取数据,分割、存放至redis
    :param redis_host:
    :param redis_port:
    :param redis_pwd
    :param mongodb_host:
    :param db:
    :param collection:
    :return:
    """
    print('start reading data....')
    data = read_data_from_mongodb(mongodb_host, db, collection)
    print('read data from mongodb finished')
    print('*' * 50)
    print('start saving data')
    data = get_split_sentences(data)
    res = redis.StrictRedis(host=redis_host, port=redis_port, db=0, password=redis_pwd)
    res.set('sentences', json.dumps(data))
    print('save data to redis finished')


def main():
    # 预训练词表存储的位置
    # 下载地址:https://huggingface.co/hfl/chinese-bert-wwm/tree/main
    original_vocab_file_path = 'data/chinese_bert_wwm/vocab.txt'
    # 语料存储位置
    corpus_file_path = 'data/pretrain_corpus.txt'
    # 自己的词表(训练效果一般取决于,词表中的词在语料中出现的次数,如果重要的词在语料中只出现了一次效果就不好)
    vocab_file_path = 'data/vocab.txt'

    # mongo和redis配置
    mongodb_config = {'host': 'mongodb://127.0.0.1:27017', 'db': 'zhihu_new', 'collection': 'articles'}
    redis_config = {'host': '127.0.0.1', 'port': '6379', 'db_index': '0', 'pwd': 'xxx'}
    # 预处理,生成预料(存至Redis,方便后续使用)
    pre_processing(
        redis_host=redis_config['host'], redis_port=redis_config['port'], redis_pwd=redis_config['pwd'],
        mongodb_host=mongodb_config['host'], db=mongodb_config['db'], collection=mongodb_config['collection']
    )
    res = redis.StrictRedis(
        host=redis_config['host'],
        port=redis_config['port'],
        db=0,
        password=redis_config['pwd']
    )
    data = json.loads(res.get('sentences'))
    data_list = []
    for each in data.keys():
        data_list.extend(data[each])
    print('start producing corpus')
    # 保存语料到文件
    save_corpus(data_list, corpus_file_path)
    print('start producing vocab')
    # 生成词表(其实就是单个的字符)
    generate_vocab(data_list, vocab_file_path, original_vocab_file_path)


if __name__ == "__main__":
    main()

这里句子分割单独封装了一个类SplitSentence,具体实现如下

def replace_with_separator(text, separator, regexs):
    replacement = r"\1" + separator + r"\2"
    result = text
    for regex in regexs:
        result = regex.sub(replacement, result)
    return result


class SplitSentence:
    """
    这个分割的方法需要根据你的数据集调整,比如针对一些没有分开的句子,添加分割符到这里
    """
    def __init__(self):
        self.separator = r'@'
        self.re_sentence = re.compile(r'(\S.+?[.!?])(?=\s+|$)|(\S.+?)(?=[\n]|$)', re.UNICODE)
        self.ab_senior = re.compile(r'([A-Z][a-z]{1,2}\.)\s(\w)', re.UNICODE)
        self.ab_acronym = re.compile(r'(\.[a-zA-Z]\.)\s(\w)', re.UNICODE)
        self.undo_ab_senior = re.compile(r'([A-Z][a-z]{1,2}\.)' + self.separator + r'(\w)', re.UNICODE)
        self.undo_ab_acronym = re.compile(r'(\.[a-zA-Z]\.)' + self.separator + r'(\w)', re.UNICODE)

    def split_sentence(self, text, best=True):
        # 句子分割,主要是通过标点符号,如果分割结果发现有些句子分割效果不好,再增加相应的分割符号
        text = re.sub('([。!??])([^”’])', r"\1\n\2", text)
        text = re.sub('(\.{6})([^”’])', r"\1\n\2", text)
        text = re.sub('(…{2})([^”’])', r"\1\n\2", text)
        text = re.sub('([。!??][”’])([^,。!??])', r'\1\n\2', text)
        for chunk in text.split("\n"):
            chunk = chunk.strip()
            if not chunk:
                continue
            if not best:
                yield chunk
                continue
            processed = replace_with_separator(chunk, self.separator, [self.ab_senior, self.ab_acronym])
            for sentence in self.re_sentence.finditer(processed):
                sentence = replace_with_separator(sentence.group(), r" ",
                                                  [self.undo_ab_senior, self.undo_ab_acronym])
                yield sentence

上面是pre_processing()方法的实现,接下来还有保存语料到文件和生成词表的操作

def save_corpus(data, corpus_file_path):
    # 语料其实就是一段段的文本(分割后的一句话)
    with open(corpus_file_path, 'w', encoding='utf-8') as f:
        for row in tqdm(data, total=len(data)):
            f.write(row + '\n')


def generate_vocab(total_data, vocab_file_path, original_vocab_file_path):
    # 以单个的字作为词表(BERT用的是字,也有其他方法是用词的)
    total_tokens = [token for sent in total_data for token in sent]
    counter = Counter(total_tokens)
    vocab = [token for token, freq in counter.items()]
    # 更新下载的预训练词表,也就是把自己词表添加到原始词表中
    # 如果只使用自己的词表,则无法fine tune成功,一定是扩充原始词表
    original_vocab = []
    with open(original_vocab_file_path, 'r', encoding='utf-8') as f:
        for line in f.readlines():
            line = line.strip('\n')
            original_vocab.append(line)
    need_add_token = [each for each in vocab if each not in original_vocab]

    original_vocab.extend(need_add_token)
    with open(vocab_file_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(original_vocab))

上面的代码是把所有的语料都放到列表中,然后全部操作完再一条条的写到文件中,很容易在执行的时候超出内存卡死,所以我优化了下,读取一条、处理一条、保存一条,运行瞬间丝滑。

"""
根据原项目自己优化的版本
原项目会把所有的句子都放在内存中,很容易电脑内存超出了
这里改为一句一句的读,分割,然后存到文件,并且不使用redis
"""

from bs4 import BeautifulSoup as bs
from pymongo import MongoClient
from tqdm import tqdm

from processing import SplitSentence


def clean_html_tag(data: str):
    """
    清除给定文本中包含的HTML标签
    :param data:
    :return:
    """
    if isinstance(data, float):
        print(data)
        return ""
    result = bs(data).get_text()
    return result


def read_data_from_mongodb(mongodb_url, db, collection):
    """
    从mongodb中读取数据
    :param mongodb_url: 连接mongodb的地址
    :param db: 数据库名称
    :param collection: 集合名称
    :return:
    """
    client = MongoClient(mongodb_url)
    collection = client[db][collection]
    for each in tqdm(collection.find(batch_size=10)):   # tqdm用于显示进度
        data = {}
        try:
            for key in ['title', 'excerpt', 'content']:
                if each[key]:
                    value = clean_html_tag(each[key])
                    data[key] = value
        except Exception as _e:
            print(_e)
            print(each)
        else:
            yield data


def generate_vocab(corpus_file_path, vocab_file_path, original_vocab_file_path):
    # 以单个的字作为词表(BERT用的是字,也有其他方法是用词的)
    vocab = set()  # 词表,不重复
    with open(corpus_file_path, 'r', encoding='utf-8') as fr:
        for line in fr.readlines():
            for word in line:
                vocab.add(word)
    # 更新下载的预训练词表,也就是把自己词表添加到原始词表中
    # 如果只使用自己的词表,则无法fine tune成功,一定是扩充原始词表
    original_vocab = []
    with open(original_vocab_file_path, 'r', encoding='utf-8') as f:
        for line in f.readlines():  # 原始词表中,每行就是一个字符
            line = line.strip('\n')
            original_vocab.append(line)
    need_add_token = vocab.difference(set(original_vocab))
    # new_vocab = vocab.union(original_vocab)  # 并集(这样会打乱原始此表的顺序)
    with open(vocab_file_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(original_vocab))
        f.write('\n')
        f.write('\n'.join(need_add_token))


def main():
    # 预训练词表存储的位置
    # 下载地址:https://huggingface.co/hfl/chinese-bert-wwm/tree/main
    original_vocab_file_path = 'data/chinese_bert_wwm/vocab.txt'
    # 语料存储位置
    corpus_file_path = 'data/pretrain_corpus.txt'
    # 自己的词表(训练效果一般取决于,词表中的词在语料中出现的次数,如果重要的词在语料中只出现了一次效果就不好)
    vocab_file_path = 'data/vocab.txt'

    # mongodb配置
    mongodb_config = {'host': 'mongodb://127.0.0.1:27017', 'db': 'zhihu_new', 'collection': 'articles'}

    # 读取数据预处理,分割句子,生成语料
    data_list = read_data_from_mongodb(
        mongodb_url=mongodb_config['host'],
        db=mongodb_config['db'],
        collection=mongodb_config['collection']
    )
    # 语料文件,每行一句文本
    print('start producing corpus')
    fw = open(corpus_file_path, 'w', encoding='utf-8')

    for data in data_list:  # data_list是一个生成器
        split_sentence = SplitSentence()
        for key, value in data.items():
            if key == 'title':
                fw.write(value+"\n")
            else:
                for each in split_sentence.split_sentence(value):
                    fw.write(each+"\n")
    fw.close()
    print('saving corpus finished')

    print('start producing vocab')
    # 生成词表(其实就是单个的字符)
    generate_vocab(corpus_file_path, vocab_file_path, original_vocab_file_path)


if __name__ == "__main__":
    main()

不过这里需要注意的是,扩充原词表时,不要改变原始词表的顺序,保持原词表顺序不变,在后面添加新词,新添加的顺序无所谓。

训练

先配置一个整体的配置文件,方便后面管理使用

CONFIG = {
    'corpus_file_path': 'data/pretrain_corpus.txt',   # 训练样本(语料)
    'vocab_file_path': 'data/vocab.txt',  # 这个词表是根据自己的数据集扩充过的
    'redis_url': '127.0.0.1',
    'redis_port': 6379,
    'max_seq_len': 102,
    'batch_size': 32,
    'output_dir': 'data/whole_word_mask_bert_output',
    'bert_model_dir': 'data/chinese_bert_wwm',  # 存放config.json和pytorch_model.bin的路径
    'debug': False  # 调试用的
}

训练代码

def seed_everyone(seed_):
    torch.manual_seed(seed_)
    torch.cuda.manual_seed_all(seed_)
    np.random.seed(seed_)
    random.seed(seed_)
    return seed_


def check_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)


class SearchDataSet(Dataset):
    def __init__(self, data_dict: dict):
        self.data_dict = data_dict

    # 重写魔术方法,可以通过索引来访问对象
    def __getitem__(self, index: int) -> tuple:
        data = (self.data_dict['input_ids'][index],
                self.data_dict['token_type_ids'][index],
                self.data_dict['attention_mask'][index])
        return data

    def __len__(self) -> int:
        return len(self.data_dict['input_ids'])


def read_data(train_file_path, tokenizer: BertTokenizer, debug=False) -> dict:
    train_data = open(train_file_path, 'r', encoding='utf-8').readlines()
    if debug:
        train_data = train_data[:2000]
    inputs = defaultdict(list)
    for row in tqdm(train_data, desc='Preprocessing train data', total=len(train_data)):
        sentence = row.strip()
        # encode
        inputs_dict = tokenizer.encode_plus(sentence, add_special_tokens=True,
                                            return_token_type_ids=True, return_attention_mask=True)
        inputs['input_ids'].append(inputs_dict['input_ids'])
        inputs['token_type_ids'].append(inputs_dict['token_type_ids'])
        inputs['attention_mask'].append(inputs_dict['attention_mask'])
    return inputs


def main():
    seed_everyone(20220531)  # 统一设置随机数种子
    train_file_path = CONFIG['corpus_file_path']
    # 加载预训练的分词器
    tokenizer = BertTokenizer.from_pretrained(CONFIG['vocab_file_path'], local_file_only=True)
    # 使用分词器读取数据(语料)
    data = read_data(train_file_path, tokenizer, CONFIG['debug'])

    train_dataset = SearchDataSet(data)
    # huggingface transformer中特有的概念,data collator 数据修补(截断和填充)
    # https://huggingface.co/docs/transformers/main/en/main_classes/data_collator#transformers.DataCollatorForLanguageModeling
    # 上面的train_dataset中只有单独的句子,还没有标签(BERT中就是被MASK的值,随机遮蔽一些值,然后预测)
    data_collator = SearchCollator(max_seq_len=CONFIG['max_seq_len'], tokenizer=tokenizer, mlm_probability=0.15)
    # 测试
    data_collator(list(train_dataset))

    output_dir = CONFIG['output_dir']
    model = BertForMaskedLM.from_pretrained(CONFIG['bert_model_dir'])

    model_save_dir = (os.path.join(output_dir, 'best_model_dir'))
    tokenizer_and_config = os.path.join(output_dir, 'tokenizer_and_config')
    check_dir(model_save_dir)
    check_dir(tokenizer_and_config)

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=20,
        fp16_backend='auto',
        per_device_train_batch_size=128,
        save_steps=500,
        logging_steps=500,
        save_total_limit=5,
        prediction_loss_only=True,
        # report_to='comet_ml',
        logging_first_step=True,
        dataloader_num_workers=4,
        disable_tqdm=False,
        seed=202203
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )

    trainer.train()
    trainer.save_model(model_save_dir)
    tokenizer.save_pretrained(tokenizer_and_config)


if __name__ == "__main__":
    main()

训练的时候可能会出现这个问题

Failed to create a directory: data/whole_word_mask_bert_output\runs\Jun21_16-47-59_user; No such file or directory

我在whole_word_mask_bert_output文件夹下又创建了一个runs文件夹解决了。

另外,如果你的debug配置忘记改为False,那么传入trainer.train()的数据只有2000条,是会报错的,IndexError: index out of range in self

报这个错误是embedding层的张量输入超过了合法范围,embedding层的合法张量输入数值范围应该在[0, num_embeddings-1]的范围内,过大过小都会报错。

关于tokenizer.encode_plus
tokenizer = BertTokenizer.from_pretrained(CONFIG['vocab_file_path'], local_file_only=True)

inputs_dict = tokenizer.encode_plus(
    sentence, 
    add_special_tokens=True,
    return_token_type_ids=True, 
    return_attention_mask=True
)
inputs['input_ids'].append(inputs_dict['input_ids'])
inputs['token_type_ids'].append(inputs_dict['token_type_ids'])
inputs['attention_mask'].append(inputs_dict['attention_mask'])

[CLS] 标志放在第一个句子的首位,经过 BERT 得到的的表征向量 C 可以用于后续的分类任务。
[SEP] 标志用于分开两个输入句子,例如输入句子 A 和 B,要在句子 A,B 后面增加 [SEP] 标志。
[UNK]标志指的是未知字符
[MASK] 标志用于遮盖句子中的一些单词,将单词用 [MASK] 遮盖之后,再利用 BERT 输出的 [MASK] 向量预测单词是什么。

特征抽取

训练(finetune)完成后,就可以使用训练得到的模型来抽取文本的特征,其实这里所说的抽取文本的特征实际就是把自然语言文本转为向量,我们直接使用原始的模型也是可以进行特征抽取的,只不过在自己的数据集上finetune之后效果会更好,而具体效果要在下游的实际任务中才能评估,仅通过finetune后的模型来将文本转为特征向量无法评估效果的好坏。

特征抽取代码示例

"""
使用BERT抽取自然语言特征
过程分为两步:1、在自己的数据集上进行finetune;2、利用finetune后的模型进行抽取
这里是特征抽取过程
"""
import torch
from transformers import BertModel, BertTokenizer


class TextVector:
    def __init__(self, device_id):
        self.tokenizer = BertTokenizer.from_pretrained(
            # 文件夹下包括config.json、pytorch_model.bin、vocab.txt
            './data/best_model_ckpt'  
        )
        self.model = BertModel.from_pretrained(
            './data/best_model_ckpt'
        )
        if torch.cuda.is_available():
            self.device = "cuda:" + str(device_id)
            self.model.to(self.device)
            print(f"bert model 加载到了cuda:{self.device}.")
        else:
            self.device = 'cpu'

    def run(self, data):
        # bert长度限制一般为512,超过长度截断(长度过长会导致参数量太大)
        if len(data) > 510:
            data = data[:510]
        inputs = self.tokenizer(data, return_tensors='pt')
        print(inputs)
        # 将数据放入cpu或者gpu
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        outputs = self.model(**inputs)
        print(outputs.pooler_output.detach().size())  # tensor类型, torch.Size([1, 768])
        print(outputs.pooler_output.detach().to("cpu").numpy().shape)  # 转为numpy, shape(1, 768)
        # 这里取0号元素后,得到的就是一个列表,reshape(1, -1)转为1行,列自动计算,又变成了(1, 768)
        data_vector = outputs.pooler_output.detach().to("cpu").numpy()[0].reshape(1, -1)
        return data_vector[0].tolist()


text_vector = TextVector(device_id=0)
res = text_vector.run("我喜欢学习")
print(len(res))  # 768
print(res)

  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值