So Good: Running BERT with Hardware-Accelerated PyTorch on Mac

Preface

At last, PyTorch supports hardware acceleration on the Mac. My reaction, in two words: so good.

Over the weekend I set up the environment on my own machine.

In this post I then compare model running times with and without Mac hardware acceleration, using a convolution operation and a BERT model.

Software Installation

Following the command given on the official website is all it takes to install the Mac hardware-accelerated (MPS) build of PyTorch.

https://pytorch.org/get-started/locally/

conda install pytorch torchvision torchaudio -c pytorch
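After installation it is worth a quick sanity check that the MPS backend is actually present (the torch.backends.mps API ships with PyTorch 1.12 and later); a minimal check:

import torch

# True only on a Metal-capable Mac (macOS 12.3+) running an MPS-enabled build.
print(torch.backends.mps.is_available())
# True if this PyTorch build was compiled with MPS support at all.
print(torch.backends.mps.is_built())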

Quick Test

A convolution operation is used to compare running time with and without hardware acceleration.

import torch
import time

# Time 1000 forward passes of a small convolution on the MPS (Apple GPU) device.
dev = 'mps:0'
conv = torch.nn.Conv2d(10, 10, 3).to(dev)
img = torch.randn(64, 10, 64, 64).to(dev)

t0 = time.time()
for i in range(1000):
    conv(img)
t1 = time.time()
print('Use mps, time:{}'.format(t1 - t0))

# Same workload on the CPU for comparison.
dev = 'cpu'
conv = torch.nn.Conv2d(10, 10, 3).to(dev)
img = torch.randn(64, 10, 64, 64).to(dev)

t0 = time.time()
for i in range(1000):
    conv(img)
t1 = time.time()
print('Use cpu, time:{}'.format(t1 - t0))
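One caveat about this micro-benchmark: work dispatched to the MPS device runs asynchronously, so stopping the clock right after the loop may not account for kernels that are still queued. A minimal sketch of a synchronized timing, which forces the last result back to the CPU (and therefore waits for the queued work) before reading the clock:

import torch
import time

dev = 'mps:0'
conv = torch.nn.Conv2d(10, 10, 3).to(dev)
img = torch.randn(64, 10, 64, 64).to(dev)

t0 = time.time()
out = None
for i in range(1000):
    out = conv(img)
# Copying the final output to the CPU blocks until the queued MPS work has finished;
# newer PyTorch builds also expose torch.mps.synchronize() for the same purpose.
out = out.to('cpu')
t1 = time.time()
print('Use mps (synchronized), time:{}'.format(t1 - t0))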

Results

BERT Test

We use Hugging Face's GLUE example code.

Data Preparation

Run the script below to download the data.

''' Script for downloading all GLUE data.

Note: for legal reasons, we are unable to host MRPC.
You can either use the version hosted by the SentEval team, which is already tokenized,
or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually.
For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example).
You should then rename and place specific files in a folder (see below for an example).

mkdir MRPC
cabextract MSRParaphraseCorpus.msi -d MRPC
cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt
cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt
rm MRPC/_*
rm MSRParaphraseCorpus.msi

1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
'''

import os
import sys
import shutil
import argparse
import tempfile
import io  # needed for io.open below (missing from the original snippet)
import urllib.request
import zipfile

TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK2PATH = {"CoLA": 'https://dl.fbaipublicfiles.com/glue/data/CoLA.zip',
             "SST": 'https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
             "QQP": 'https://dl.fbaipublicfiles.com/glue/data/QQP-clean.zip',
             "STS": 'https://dl.fbaipublicfiles.com/glue/data/STS-B.zip',
             "MNLI": 'https://dl.fbaipublicfiles.com/glue/data/MNLI.zip',
             "QNLI": 'https://dl.fbaipublicfiles.com/glue/data/QNLIv2.zip',
             "RTE": 'https://dl.fbaipublicfiles.com/glue/data/RTE.zip',
             "WNLI": 'https://dl.fbaipublicfiles.com/glue/data/WNLI.zip',
             "diagnostic": 'https://dl.fbaipublicfiles.com/glue/data/AX.tsv'}

MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'


def download_and_extract(task, data_dir):
    print("Downloading and extracting %s..." % task)
    if task == "MNLI":
        print(
            "\tNote (12/10/20): This script no longer downloads SNLI. You will need to manually download and format the data to use SNLI.")
    data_file = "%s.zip" % task
    urllib.request.urlretrieve(TASK2PATH[task], data_file)
    with zipfile.ZipFile(data_file) as zip_ref:
        zip_ref.extractall(data_dir)
    os.remove(data_file)
    print("\tCompleted!")


def format_mrpc(data_dir, path_to_data):
    print("Processing MRPC...")
    mrpc_dir = os.path.join(data_dir, "MRPC")
    if not os.path.isdir(mrpc_dir):
        os.mkdir(mrpc_dir)
    if path_to_data:
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    else:
        try:
            mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
            mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
            # Fixed: the original snippet called an undefined URLLIB alias here.
            urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
            urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
        except urllib.error.HTTPError:
            print("Error downloading MRPC")
            return
    assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
    assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file

    with io.open(mrpc_test_file, encoding='utf-8') as data_fh, \
            io.open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding='utf-8') as test_fh:
        header = data_fh.readline()
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        for idx, row in enumerate(data_fh):
            label, id1, id2, s1, s2 = row.strip().split('\t')
            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))

    try:
        # Note: TASK2PATH above has no "MRPC" entry, so this hits the KeyError branch
        # unless you add the dev_ids URL yourself.
        urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
    except (KeyError, urllib.error.HTTPError):
        # Fixed: "except KeyError or urllib.error.HTTPError" only catches KeyError.
        print("\tError downloading standard development IDs for MRPC. You will need to manually split your data.")
        return

    dev_ids = []
    with io.open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding='utf-8') as ids_fh:
        for row in ids_fh:
            dev_ids.append(row.strip().split('\t'))

    with io.open(mrpc_train_file, encoding='utf-8') as data_fh, \
            io.open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding='utf-8') as train_fh, \
            io.open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding='utf-8') as dev_fh:
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            label, id1, id2, s1, s2 = row.strip().split('\t')
            if [id1, id2] in dev_ids:
                dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
            else:
                train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))

    print("\tCompleted!")


def download_diagnostic(data_dir):
    print("Downloading and extracting diagnostic...")
    if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
        os.mkdir(os.path.join(data_dir, "diagnostic"))
    data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
    urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file)
    print("\tCompleted!")
    return


def get_tasks(task_names):
    task_names = task_names.split(',')
    if "all" in task_names:
        tasks = TASKS
    else:
        tasks = []
        for task_name in task_names:
            assert task_name in TASKS, "Task %s not found!" % task_name
            tasks.append(task_name)
    return tasks


def main(arguments):
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
    parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
                        type=str, default='all')
    parser.add_argument('--path_to_mrpc',
                        help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt',
                        type=str, default='')
    args = parser.parse_args(arguments)

    if not os.path.isdir(args.data_dir):
        os.mkdir(args.data_dir)
    tasks = get_tasks(args.tasks)

    for task in tasks:
        if task == 'MRPC':
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == 'diagnostic':
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
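Assuming the script above is saved as download_glue_data.py (the file name is my own choice here, not fixed by the script), downloading just the MRPC data used in this post looks like this:

python download_glue_data.py --data_dir glue_data --tasks MRPC

The --data_dir and --tasks arguments come straight from the argparse definitions above; --tasks defaults to all, which pulls every GLUE task.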

Environment Preparation

The requirements file contents are as follows:

accelerate
datasets >= 1.8.0
sentencepiece != 0.1.92
scipy
scikit-learn
protobuf
numpy==1.17.3
# torch >= 1.3
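Assuming the list above is saved as requirements.txt, it installs in the usual way:

pip install -r requirements.txt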

Code Preparation

We use Hugging Face's run_glue_no_trainer.py.

The run script is as follows:

export TASK_NAME=mrpc

python run_glue_no_trainer.py \
  --model_name_or_path Pretrained_LMs/bert-base-cased \
  --task_name $TASK_NAME \
  --max_length 128 \
  --per_device_train_batch_size 32 \
  --learning_rate 2e-5 \
  --num_train_epochs 3 \
  --output_dir ./output/$TASK_NAME/

The running device is switched inside the code as follows:

    accelerator.state.device = 'mps'
    print('-' * 100)
    print(accelerator.state.device)
    print('-' * 100)
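run_glue_no_trainer.py has no built-in switch for MPS (at the time of writing), so the assignment above is a small hack. A slightly more defensive sketch, assuming it is placed right after the Accelerator object is created, switches devices only when the MPS backend is actually available (otherwise the default device is kept):

import torch

# Guarded variant of the hack above (placement right after `accelerator = Accelerator(...)`
# is assumed). The original assigned the plain string 'mps', which also works for
# subsequent .to() calls; torch.device just keeps the attribute type-consistent.
if torch.backends.mps.is_available():
    accelerator.state.device = torch.device('mps')
print('-' * 100)
print(accelerator.state.device)
print('-' * 100)

If an individual operator turns out not to be implemented on MPS yet, setting the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 before launching lets PyTorch fall back to the CPU for that op.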

Results

Running on the CPU takes about 1 hour:

Num processes: 1
Process index: 0
Local process index: 0
Device: cpu
...
07/16/2022 17:13:00 - INFO - __main__ - ***** Running training *****
07/16/2022 17:13:00 - INFO - __main__ -   Num examples = 3668
07/16/2022 17:13:00 - INFO - __main__ -   Num Epochs = 3
07/16/2022 17:13:00 - INFO - __main__ -   Instantaneous batch size per device = 32
07/16/2022 17:13:00 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 32
07/16/2022 17:13:00 - INFO - __main__ -   Gradient Accumulation steps = 1
07/16/2022 17:13:00 - INFO - __main__ -   Total optimization steps = 345
  2%|███▌                                                                                                                                                                                                       | 6/345 [01:06<1:03:49, 11.30s/it]

With hardware acceleration it takes about 20 minutes, roughly a 3x speedup per step (3.71 s/it vs 11.30 s/it):

Num processes: 1
Process index: 0
Local process index: 0
Device: mps
...
07/16/2022 17:14:29 - INFO - __main__ - ***** Running training *****
07/16/2022 17:14:29 - INFO - __main__ -   Num examples = 3668
07/16/2022 17:14:29 - INFO - __main__ -   Num Epochs = 3
07/16/2022 17:14:29 - INFO - __main__ -   Instantaneous batch size per device = 32
07/16/2022 17:14:29 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 32
07/16/2022 17:14:29 - INFO - __main__ -   Gradient Accumulation steps = 1
07/16/2022 17:14:29 - INFO - __main__ -   Total optimization steps = 345
  5%|██████████▋                                                                                                                                                                                                 | 18/345 [01:03<20:14,  3.71s/it]

Watching the Mac Activity Monitor confirms that the program is indeed using the GPU for acceleration.

Bug Fix

The following error appeared during the run:

OMP: Error #15: Initializing libomp.dylib, but found libiomp5.dylib already initialized.

Following the linked discussion solved the problem: when Python is installed via Conda, the MKL bundled with Conda's numpy package easily conflicts with the system library; the options are to update the numpy package in Conda or to switch to the system library.

Solution: downgrade numpy; here I downgraded it to 1.17.3:

pip install numpy==1.17.3
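For reference, the OMP error message itself also mentions an environment-variable workaround, while explicitly calling it unsafe (it may crash or silently produce wrong results), so it is best treated as a last resort rather than a fix:

export KMP_DUPLICATE_LIB_OK=TRUE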
