ROCm上运行自然语言推断:微调BERT

75 篇文章 0 订阅
7 篇文章 0 订阅

15.7. 自然语言推断:微调BERT — 动手学深度学习 2.0.0 documentation (d2l.ai)

代码

import json
import multiprocessing
import os
import torch
from torch import nn
from d2l import torch as d2l

d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip',
                             '225d66f04cae318b841a13d32af3acc165f253ac')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip',
                              'c72329e68a732bef0452e4b96a1c341c8910f81f')

def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
                          num_heads, num_layers, dropout, max_len, devices):
    data_dir = d2l.download_extract(pretrained_model)
    # 定义空词表以加载预定义词表
    vocab = d2l.Vocab()
    vocab.idx_to_token = json.load(open(os.path.join(data_dir,
        'vocab.json')))
    vocab.token_to_idx = {token: idx for idx, token in enumerate(
        vocab.idx_to_token)}
    bert = d2l.BERTModel(len(vocab), num_hiddens, norm_shape=[256],
                         ffn_num_input=256, ffn_num_hiddens=ffn_num_hiddens,
                         num_heads=4, num_layers=2, dropout=0.2,
                         max_len=max_len, key_size=256, query_size=256,
                         value_size=256, hid_in_features=256,
                         mlm_in_features=256, nsp_in_features=256)
    # 加载预训练BERT参数
    bert.load_state_dict(torch.load(os.path.join(data_dir,
                                                 'pretrained.params')))
    return bert, vocab

devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
    'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
    num_layers=2, dropout=0.1, max_len=512, devices=devices)

class SNLIBERTDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, max_len, vocab=None):
        all_premise_hypothesis_tokens = [[
            p_tokens, h_tokens] for p_tokens, h_tokens in zip(
            *[d2l.tokenize([s.lower() for s in sentences])
              for sentences in dataset[:2]])]

        self.labels = torch.tensor(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        (self.all_token_ids, self.all_segments,
         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
        print('read ' + str(len(self.all_token_ids)) + ' examples')

    def _preprocess(self, all_premise_hypothesis_tokens):
        pool = multiprocessing.Pool(4)  # 使用4个进程
        out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
        all_token_ids = [
            token_ids for token_ids, segments, valid_len in out]
        all_segments = [segments for token_ids, segments, valid_len in out]
        valid_lens = [valid_len for token_ids, segments, valid_len in out]
        return (torch.tensor(all_token_ids, dtype=torch.long),
                torch.tensor(all_segments, dtype=torch.long),
                torch.tensor(valid_lens))

    def _mp_worker(self, premise_hypothesis_tokens):
        p_tokens, h_tokens = premise_hypothesis_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
        token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \
                             * (self.max_len - len(tokens))
        segments = segments + [0] * (self.max_len - len(segments))
        valid_len = len(tokens)
        return token_ids, segments, valid_len

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        # 为BERT输入中的'<CLS>'、'<SEP>'和'<SEP>'词元保留位置
        while len(p_tokens) + len(h_tokens) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx]), self.labels[idx]

    def __len__(self):
        return len(self.all_token_ids)

# 如果出现显存不足错误,请减少“batch_size”。在原始的BERT模型中,max_len=512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
                                   num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                  num_workers=num_workers)

class BERTClassifier(nn.Module):
    def __init__(self, bert):
        super(BERTClassifier, self).__init__()
        self.encoder = bert.encoder
        self.hidden = bert.hidden
        self.output = nn.Linear(256, 3)

    def forward(self, inputs):
        tokens_X, segments_X, valid_lens_x = inputs
        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
        return self.output(self.hidden(encoded_X[:, 0, :]))

net = BERTClassifier(bert)

lr, num_epochs = 1e-4, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
    devices)

代码解析

This block of code aims to use the BERT model for a natural language processing task. Specifically, it appears to address a sentence classification problem using a dataset from the Stanford Natural Language Inference (SNLI) corpus. The code outlines the following steps:
1. Importing necessary libraries and setting up configurations.
2. Downloading pre-trained models (BERT base and BERT small) from a specified d2l.DATA_HUB.
3. Defining a function load_pretrained_model to load the pre-trained BERT model with specified configurations such as the number of hidden units, number of attention heads, etc.
4. Specifying the device(s) to be used for computations (CPU or GPU).
5. Loading and preparing the SNLI dataset for the model by tokenizing and converting into tensors.
6. Defining a SNLIBERTDataset PyTorch dataset class to handle data loading for SNLI.
7. Creating a data iterator for batching and shuffling of the dataset.
8. Defining a BERTClassifier class that extends nn.Module, which includes the encoder from BERT along with an additional output layer for classification.
9. Setting up hyperparameters such as the learning rate and number of epochs, defining the optimizer (Adam in this case), and the loss function (CrossEntropyLoss for classification).
10. Training the BERT classifier model on the SNLI dataset using the defined training iterator, loss function, optimizer, and training utility provided by d2l.
Some notable points from the code snippet:
- d2l seems to be an import from a package related to "Dive into Deep Learning" (a textbook and associated resources for learning deep learning).
- The load_pretrained_model function loads a model and vocabulary from a directory after downloading and unzipping the pre-trained model using d2l utilities.
- The SNLIBERTDataset class processes the SNLI dataset for use with BERT by tokenizing premises and hypotheses from the dataset and creating corresponding token IDs and segment IDs for input to the model.
- A multiprocessing Pool is used to preprocess the data for more efficient data preparation.
- The BERTClassifier calls the encoder from the BERT model and adds a linear output layer to map the encoder outputs to class labels for the SNLI tasks.
- The train_ch13 function from d2l is used to train the model, passing in the defined network, training and test iterators, the loss function, trainer, number of epochs, and device configuration.
To ensure this code runs correctly, you would need to have the required d2l package installed and the necessary hardware (with available memory) to handle the batch size and model parameters. Adjustments may be needed if memory limitations are encountered. 

这段代码的目的是利用BERT模型来进行一项自然语言处理任务。具体而言,它似乎是解决了一个使用斯坦福自然语言推断(SNLI)数据库的句子分类问题。这段代码概述了以下步骤:
1. 导入必要的库并设置配置。
2. 从指定的`d2l.DATA_HUB`下载预训练模型(BERT基础版和BERT小型版)。
3. 定义一个函数`load_pretrained_model`来加载预训练的BERT模型,使用的配置包括隐藏单元、注意力头等的数量。
4. 指定用于计算的设备(CPU或GPU)。
5. 加载并准备SNLI数据库中的数据以供模型使用,包括对句子进行分词并转换成张量。
6. 定义一个PyTorch数据集类`SNLIBERTDataset`来处理SNLI数据库的数据加载。
7. 为数据集创建数据迭代器以进行批处理和混洗。
8. 定义一个扩展自`nn.Module`的`BERTClassifier`类,该类包含BERT的编码器以及用于分类的额外输出层。
9. 设置超参数,如学习率和训练周期数,定义优化器(在此例中用的是`Adam`),以及损失函数(分类任务使用`CrossEntropyLoss`)。
10. 使用定义的训练迭代器、损失函数、优化器以及`d2l`提供的训练工具来对SNLI数据库中的BERT分类器模型进行训练。
代码片段中一些值得注意的点:
- d2l似乎是从与"深入学习系列教材"相关的软件包导入的,教材附属的资源用于学习深度学习。
- load_pretrained_model函数在下载和解压缩预训练模型后使用d2l工具从目录加载模型和词表。
- SNLIBERTDataset类处理SNLI数据库的数据,用于BERT的使用,分词数据集中的前提和假设,并创建相应的令牌ID和段ID作为模型的输入。
- 使用多进程Pool来更高效地进行数据处理。
- BERTClassifier调用了BERT模型的编码器,并添加了一个线性输出层,将编码器输出映射到SNLI任务的类标签。
- 使用`d2l`的`train_ch13`函数来训练模型,传入定义的网络、训练和测试迭代器、损失函数、训练器、训练周期数和设备配置。
为确保这段代码正确运行,你需要安装必要的`d2l`包,并具备处理批处理大小和模型参数所需的硬件(具有可用内存)。如果遇到内存限制,则可能需要进行调整。

  • 23
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

109702008

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值