MindSpore 25-Day Learning Camp, Day 20 | GPT-2 Text Summarization with MindSpore

Introduction to GPT-2 Text Summarization

GPT-2 (Generative Pre-trained Transformer 2) is a Transformer-based pre-trained language model developed by OpenAI. As the successor to GPT-1, it further improves performance on natural language processing tasks through a larger training corpus and a more complex model architecture.

One of GPT-2's main capabilities is text generation, including text summarization. Summarization means extracting the key information from an input text and producing a concise overview. In GPT-2 this ability comes from the language model learned during pre-training: the model captures the structure and content of natural language and can generate a reasonable summary conditioned on the input context.

When GPT-2 is used for summarization, the model is typically given a long text or article and generates a shorter summary covering its main content and key points. This technique has important applications in information retrieval, automated writing, and content generation, helping users quickly digest large amounts of text.

GPT-2 in Practice

Practice Environment

Python: 3.9.19

Install dependencies

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14

pip install tokenizers==0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

pip install mindnlp

Full list of Python dependencies

pip list
Package                        Version
------------------------------ --------------
absl-py                        2.1.0
addict                         2.4.0
aiofiles                       22.1.0
aiohttp                        3.9.5
aiosignal                      1.3.1
aiosqlite                      0.20.0
altair                         5.3.0
annotated-types                0.7.0
anyio                          4.4.0
argon2-cffi                    23.1.0
argon2-cffi-bindings           21.2.0
arrow                          1.3.0
astroid                        3.2.2
asttokens                      2.0.5
astunparse                     1.6.3
async-timeout                  4.0.3
attrs                          23.2.0
auto-tune                      0.1.0
autopep8                       1.5.5
Babel                          2.15.0
backcall                       0.2.0
beautifulsoup4                 4.12.3
black                          24.4.2
bleach                         6.1.0
certifi                        2024.6.2
cffi                           1.16.0
charset-normalizer             3.3.2
click                          8.1.7
cloudpickle                    3.0.0
colorama                       0.4.6
comm                           0.2.1
contextlib2                    21.6.0
contourpy                      1.2.1
cycler                         0.12.1
dataflow                       0.0.1
datasets                       2.20.0
debugpy                        1.6.7
decorator                      5.1.1
defusedxml                     0.7.1
dill                           0.3.8
dnspython                      2.6.1
download                       0.3.5
easydict                       1.13
email_validator                2.2.0
entrypoints                    0.4
evaluate                       0.4.2
exceptiongroup                 1.2.0
executing                      0.8.3
fastapi                        0.111.0
fastapi-cli                    0.0.4
fastjsonschema                 2.20.0
ffmpy                          0.3.2
filelock                       3.15.3
flake8                         3.8.4
fonttools                      4.53.0
fqdn                           1.5.1
frozenlist                     1.4.1
fsspec                         2024.5.0
gitdb                          4.0.11
GitPython                      3.1.43
gradio                         4.26.0
gradio_client                  0.15.1
h11                            0.14.0
hccl                           0.1.0
hccl-parser                    0.1
httpcore                       1.0.5
httptools                      0.6.1
httpx                          0.27.0
huggingface-hub                0.23.4
hypothesis                     6.105.1
idna                           3.7
importlib-metadata             7.0.1
importlib_resources            6.4.0
iniconfig                      2.0.0
ipykernel                      6.28.0
ipympl                         0.9.4
ipython                        8.15.0
ipython-genutils               0.2.0
ipywidgets                     8.1.3
isoduration                    20.11.0
isort                          5.13.2
jedi                           0.17.2
jieba                          0.42.1
Jinja2                         3.1.4
joblib                         1.4.2
json5                          0.9.25
jsonpointer                    3.0.0
jsonschema                     4.22.0
jsonschema-specifications      2023.12.1
jupyter_client                 7.4.9
jupyter_core                   5.7.2
jupyter-events                 0.10.0
jupyter-lsp                    2.2.5
jupyter-resource-usage         0.7.2
jupyter_server                 2.14.1
jupyter_server_fileid          0.9.2
jupyter-server-mathjax         0.2.6
jupyter_server_terminals       0.5.3
jupyter_server_ydoc            0.8.0
jupyter-ydoc                   0.2.5
jupyterlab                     3.6.7
jupyterlab_code_formatter      2.2.1
jupyterlab_git                 0.50.1
jupyterlab-language-pack-zh-CN 4.2.post1
jupyterlab-lsp                 4.3.0
jupyterlab_pygments            0.3.0
jupyterlab_server              2.27.2
jupyterlab-system-monitor      0.8.0
jupyterlab-topbar              0.6.1
jupyterlab_widgets             3.0.11
kiwisolver                     1.4.5
markdown-it-py                 3.0.0
MarkupSafe                     2.1.5
matplotlib                     3.9.0
matplotlib-inline              0.1.6
mccabe                         0.6.1
mdurl                          0.1.2
mindnlp                        0.3.1
mindspore                      2.2.14
mindvision                     0.1.0
mistune                        3.0.2
ml_collections                 0.1.1
ml-dtypes                      0.4.0
mpmath                         1.3.0
msadvisor                      1.0.0
multidict                      6.0.5
multiprocess                   0.70.16
mypy-extensions                1.0.0
nbclassic                      1.1.0
nbclient                       0.10.0
nbconvert                      7.16.4
nbdime                         4.0.1
nbformat                       5.10.4
nest-asyncio                   1.6.0
notebook                       6.5.7
notebook_shim                  0.2.4
numpy                          1.26.4
op-compile-tool                0.1.0
op-gen                         0.1
op-test-frame                  0.1
opc-tool                       0.1.0
opencv-contrib-python-headless 4.10.0.84
opencv-python                  4.10.0.84
opencv-python-headless         4.10.0.84
orjson                         3.10.5
overrides                      7.7.0
packaging                      23.2
pandas                         2.2.2
pandocfilters                  1.5.1
parso                          0.7.1
pathlib2                       2.3.7.post1
pathspec                       0.12.1
pexpect                        4.8.0
pickleshare                    0.7.5
pillow                         10.3.0
pip                            24.1
platformdirs                   4.2.2
pluggy                         1.5.0
prometheus_client              0.20.0
prompt-toolkit                 3.0.43
protobuf                       5.27.1
psutil                         5.9.0
ptyprocess                     0.7.0
pure-eval                      0.2.2
pyarrow                        16.1.0
pyarrow-hotfix                 0.6
pycodestyle                    2.6.0
pycparser                      2.22
pyctcdecode                    0.5.0
pydantic                       2.7.4
pydantic_core                  2.18.4
pydocstyle                     6.3.0
pydub                          0.25.1
pyflakes                       2.2.0
Pygments                       2.15.1
pygtrie                        2.5.0
pylint                         3.2.3
pyparsing                      3.1.2
pytest                         7.2.0
python-dateutil                2.9.0.post0
python-dotenv                  1.0.1
python-json-logger             2.0.7
python-jsonrpc-server          0.4.0
python-language-server         0.36.2
python-multipart               0.0.9
pytoolconfig                   1.3.1
pytz                           2024.1
PyYAML                         6.0.1
pyzmq                          25.1.2
referencing                    0.35.1
regex                          2024.5.15
requests                       2.32.3
rfc3339-validator              0.1.4
rfc3986-validator              0.1.1
rich                           13.7.1
rope                           1.13.0
rpds-py                        0.18.1
ruff                           0.4.10
safetensors                    0.4.3
schedule-search                0.0.1
scikit-learn                   1.5.0
scipy                          1.13.1
semantic-version               2.10.0
Send2Trash                     1.8.3
sentencepiece                  0.2.0
setuptools                     69.5.1
shellingham                    1.5.4
six                            1.16.0
smmap                          5.0.1
sniffio                        1.3.1
snowballstemmer                2.2.0
sortedcontainers               2.4.0
soupsieve                      2.5
stack-data                     0.2.0
starlette                      0.37.2
sympy                          1.12.1
synr                           0.5.0
te                             0.4.0
terminado                      0.18.1
threadpoolctl                  3.5.0
tinycss2                       1.3.0
tokenizers                     0.15.0
toml                           0.10.2
tomli                          2.0.1
tomlkit                        0.12.0
toolz                          0.12.1
tornado                        6.4.1
tqdm                           4.66.4
traitlets                      5.14.3
typer                          0.12.3
types-python-dateutil          2.9.0.20240316
typing_extensions              4.11.0
tzdata                         2024.1
ujson                          5.10.0
uri-template                   1.3.0
urllib3                        2.2.2
uvicorn                        0.30.1
uvloop                         0.19.0
watchfiles                     0.22.0
wcwidth                        0.2.5
webcolors                      24.6.0
webencodings                   0.5.1
websocket-client               1.8.0
websockets                     11.0.3
wheel                          0.43.0
widgetsnbextension             4.0.11
xxhash                         3.4.1
y-py                           0.6.2
yapf                           0.40.2
yarl                           1.9.4
ypy-websocket                  0.8.4
zipp                           3.17.0

Practice Code

1. Download and load the dataset

from mindnlp.utils import http_get

# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')

from mindspore.dataset import TextFileDataset

# load dataset
# dataset = TextFileDataset(str(path), shuffle=False)  # the full dataset trains too slowly, so we sample instead
dataset = TextFileDataset(str(path), shuffle=True, num_samples=5000)
# dataset.get_dataset_size()

# split into training and testing dataset
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)
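
Each line of train_with_summ.txt is a JSON object; the preprocessing code below relies on its 'article' and 'summarization' keys. A minimal sketch to peek at one raw record (assuming that file layout):

import json

# read the first line of the downloaded file and inspect its keys
with open(str(path), encoding='utf-8') as f:
    record = json.loads(f.readline())
print(list(record.keys()))  # expect 'article' and 'summarization' among the keys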

2. Data preprocessing

    Raw data format:

    article: [CLS] article_context [SEP]
    summary: [CLS] summary_context [SEP]

    Format after preprocessing:

    [CLS] article_context [SEP] summary_context [SEP]

import json
import numpy as np

# preprocess dataset
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def merge_and_pad(article, summary):
        # tokenize article and summary into one sequence:
        #   [CLS] article [SEP] summary [SEP]
        # pad to max_seq_len and truncate only the article when too long
        tokenized = tokenizer(text=article, text_pair=summary,
                              padding='max_length', truncation='only_first', max_length=max_seq_len)
        # labels equal input_ids; the shift-right happens inside the model
        return tokenized['input_ids'], tokenized['input_ids']
    
    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    # change column names to input_ids and labels for the following training
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])

    dataset = dataset.batch(batch_size)
    if shuffle:
        dataset = dataset.shuffle(batch_size)

    return dataset


# GPT-2 has no Chinese tokenizer, so we use BertTokenizer instead.

from mindnlp.transformers import BertTokenizer

# use BertTokenizer to tokenize Chinese text
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
len(tokenizer)

train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)

next(train_dataset.create_tuple_iterator())
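
To visualize the merged [CLS] article [SEP] summary [SEP] layout, here is a quick sketch with a toy article/summary pair (illustrative strings, not from the dataset):

# toy pair to show the merged format produced by merge_and_pad
demo = tokenizer(text='今天天气很好,适合出门散步。', text_pair='天气好')
print(tokenizer.decode(demo['input_ids']))
# expected shape: [CLS] 今 天 天 气 很 好 ... [SEP] 天 气 好 [SEP]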

Model construction

  1. Build the GPT2ForSummarization model; note the shift-right operation (a worked example follows the class definition).

from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModel

class GPT2ForSummarization(GPT2LMHeadModel):
    def construct(
        self,
        input_ids = None,
        attention_mask = None,
        labels = None,
    ):
        outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)
        # shift right: logits at position t are scored against the token at position t+1
        shift_logits = outputs.logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        # Flatten the tokens
        loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)
        return loss
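
To make the shift-right concrete, here is a small numpy sketch with hypothetical token ids (101 for [CLS], 102 for [SEP]): the logits at position t are scored against the token at position t+1, so the last logit and the first label are dropped.

import numpy as np

ids = np.array([101, 11, 12, 102, 21, 22, 102])  # hypothetical merged sequence
contexts = ids[:-1]  # positions whose logits enter the loss
targets = ids[1:]    # the token each position must predict
for c, t in zip(contexts, targets):
    print(f'position holding token {c} -> predicts {t}')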

# Dynamic learning rate schedule
from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

class LinearWithWarmUp(LearningRateSchedule):
    """
    Warmup-decay learning rate.
    """
    def __init__(self, learning_rate, num_warmup_steps, num_training_steps):
        super().__init__()
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps

    def construct(self, global_step):
        if global_step < self.num_warmup_steps:
            return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate
        return ops.maximum(
            0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))
        ) * self.learning_rate
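
A quick plain-Python evaluation of the same formula, using the hyperparameters set below and an assumed total of 10000 steps, shows the schedule's shape:

# same math as LinearWithWarmUp.construct, in plain Python for inspection
def lr_at(step, lr=1.5e-4, warmup=2000, total=10000):
    if step < warmup:
        return step / max(1, warmup) * lr  # linear warmup
    return max(0.0, (total - step) / max(1, total - warmup)) * lr  # linear decay

for step in (0, 1000, 2000, 6000, 10000):
    print(step, lr_at(step))
# 0 -> 0.0, 1000 -> 7.5e-05, 2000 -> 0.00015, 6000 -> 7.5e-05, 10000 -> 0.0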

Model training

num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4

num_training_steps = num_epochs * train_dataset.get_dataset_size()

from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)

lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)

# print the number of model parameters
print('number of model parameters: {}'.format(model.num_parameters()))

number of model parameters: 102068736
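
The 102,068,736 figure matches the GPT2Config defaults (12 layers, hidden size 768, 1024 positions) combined with the 21,128-token bert-base-chinese vocabulary; the LM head is tied to the token embedding, so it adds no parameters. A back-of-the-envelope check under those assumed defaults:

# parameter accounting under assumed GPT2Config defaults
vocab, pos, d, layers = 21128, 1024, 768, 12
embeddings = vocab * d + pos * d   # token + position embedding tables
per_layer = 12 * d * d + 13 * d    # attn (4d^2) + MLP (8d^2) weights, plus biases and 2 LayerNorms
final_ln = 2 * d                   # final LayerNorm weight and bias
print(embeddings + layers * per_layer + final_ln)  # 102068736
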
from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallback

ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',
                                epochs=1, keep_checkpoint_max=2)

trainer = Trainer(network=model, train_dataset=train_dataset,
                  epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # enable O1 mixed precision

Note: use a higher-spec accelerator if possible; training takes a long time.

trainer.run(tgt_columns="labels")

Model inference

Data processing: prepare the test data so the generated token ids can be decoded back into Chinese text.

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def pad(article):
        # truncate the article so max_summary_len tokens remain for generation
        tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)
        return tokenized['input_ids']

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    dataset = dataset.map(pad, 'article', ['input_ids'])
    
    dataset = dataset.batch(batch_size)

    return dataset


test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)

print(next(test_dataset.create_tuple_iterator(output_numpy=True)))

model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)

model.set_train(False)
# BertTokenizer ends sequences with [SEP], so use it as the EOS token for generation
model.config.eos_token_id = model.config.sep_token_id
i = 0
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():
    output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)
    output_text = tokenizer.decode(output_ids[0].tolist())
    print(output_text)
    i += 1
    if i == 1:  # only show the first sample
        break
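
The decoded text above contains the article followed by the generated summary, since generate() returns the prompt plus the continuation, and BertTokenizer inserts spaces between Chinese characters. A hedged post-processing sketch to keep only the summary:

# decode only the newly generated tokens and strip inter-character spaces
prompt_len = input_ids.shape[1]
summary_ids = output_ids[0].tolist()[prompt_len:]
summary = tokenizer.decode(summary_ids, skip_special_tokens=True).replace(' ', '')
print(summary)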
