昇思MindSpore Study Notes 2-02, LLM Principles and Practice: GPT-2 Text Summarization Based on MindSpore

Abstract:

An introduction to the method and steps for generative GPT-2 training with the 昇思MindSpore AI framework, using the nlpcc2017 dataset of 50,000 news articles and their summaries.

I. Environment Setup

%%capture captured_output
# The environment comes with mindspore==2.2.14 preinstalled; to switch MindSpore versions, change the version number below
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
!pip install tokenizers==0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# This case was verified against mindnlp 0.3.1; if it fails to run, pin the version with `!pip install mindnlp==0.3.1`
!pip install mindnlp

Output:

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting tokenizers==0.15.0
  Downloading tokenizers-0.15.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (3.8 MB)
...
Installing collected packages: tokenizers
Successfully installed tokenizers-0.15.0

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting mindnlp
  Downloading mindnlp-0.3.1-py3-none-any.whl (5.7 MB)
...
Successfully built jieba
Installing collected packages: sortedcontainers, sentencepiece, pygtrie, jieba, addict, xxhash, safetensors, regex, pytest, pyarrow-hotfix, pyarrow, multiprocess, multidict, ml-dtypes, hypothesis, fsspec, frozenlist, async-timeout, yarl, pyctcdecode, aiosignal, aiohttp, datasets, evaluate, mindnlp
Successfully installed addict-2.4.0 aiohttp-3.9.5 aiosignal-1.3.1 async-timeout-4.0.3 datasets-2.20.0 evaluate-0.4.2 frozenlist-1.4.1 fsspec-2024.5.0 hypothesis-6.104.2 jieba-0.42.1 mindnlp-0.3.1 ml-dtypes-0.4.0 multidict-6.0.5 multiprocess-0.70.16 pyarrow-16.1.0 pyarrow-hotfix-0.6 pyctcdecode-0.5.0 pygtrie-2.5.0 pytest-7.2.0 regex-2024.5.15 safetensors-0.4.3 sentencepiece-0.2.0 sortedcontainers-2.4.0 xxhash-3.4.1 yarl-1.9.4

II. Dataset Loading and Processing

1. Download the dataset

The nlpcc2017 summarization dataset

Dataset contents:

News articles and their summaries

50,000 samples

from mindnlp.utils import http_get

# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')

Output:

Building prefix dict from the default dictionary ...
Dumping model to file cache /tmp/jieba.cache
Loading model cost 1.037 seconds.
Prefix dict has been built successfully.
145M downloaded [00:06<00:00, 22.7MB/s]

2. Load the dataset

from mindspore.dataset import TextFileDataset

# load dataset
dataset = TextFileDataset(str(path), shuffle=False)
dataset.get_dataset_size()

Output:

50000

Split into training and test datasets

# split into training and testing dataset
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)

3. Data preprocessing

Raw data format:

article: [CLS] article_context [SEP]

summary: [CLS] summary_context [SEP]

Format after preprocessing:

[CLS] article_context [SEP] summary_context [SEP]
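To make the merged layout concrete, here is a minimal sketch (the article/summary strings are invented placeholders; it assumes the bert-base-chinese tokenizer loaded in the next subsection):

from mindnlp.transformers import BertTokenizer

tok = BertTokenizer.from_pretrained('bert-base-chinese')
merged = tok(text='新闻正文', text_pair='摘要')  # toy article/summary pair
print(tok.convert_ids_to_tokens(merged['input_ids']))
# ['[CLS]', '新', '闻', '正', '文', '[SEP]', '摘', '要', '[SEP]']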

import json
import numpy as np

# preprocess dataset
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def merge_and_pad(article, summary):
        # tokenization
        # pad to max_seq_length, only truncate the article
        tokenized = tokenizer(text=article, text_pair=summary,
                              padding='max_length', truncation='only_first', max_length=max_seq_len)
        # input_ids and labels are identical here; the shift happens inside the model
        return tokenized['input_ids'], tokenized['input_ids']

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    # change column names to input_ids and labels for the following training
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])

    dataset = dataset.batch(batch_size)
    if shuffle:
        dataset = dataset.shuffle(batch_size)

    return dataset
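For reference, every line of train_with_summ.txt is a JSON object with 'article' and 'summarization' keys, which is exactly what read_map unpacks. A small sketch with an invented sample line:

import json
import numpy as np

raw_line = '{"article": "some news text ...", "summarization": "a short summary"}'
data = json.loads(raw_line)
# read_map returns 0-d string arrays, the column format expected by dataset.map
article, summary = np.array(data['article']), np.array(data['summarization'])
print(article, summary)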

Tokenize with BertTokenizer

from mindnlp.transformers import BertTokenizer

# We use BertTokenizer for tokenizing Chinese context.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
len(tokenizer)

Output:

(tokenizer config and vocab files downloaded)
21128

train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)
next(train_dataset.create_tuple_iterator())

Output:

[Tensor(shape=[4, 1024], dtype=Int64, value=
 [[ 101, 1724, 3862 ...    0,    0,    0],
  [ 101,  704, 3173 ...    0,    0,    0],
  [ 101, 1079, 2159 ... 1745, 8021,  102],
  [ 101, 1355, 2357 ...    0,    0,    0]]),
 Tensor(shape=[4, 1024], dtype=Int64, value=
 [[ 101, 1724, 3862 ...    0,    0,    0],
  [ 101,  704, 3173 ...    0,    0,    0],
  [ 101, 1079, 2159 ... 1745, 8021,  102],
  [ 101, 1355, 2357 ...    0,    0,    0]])]

III. Model Construction

1. Build the GPT2ForSummarization model

Note the shift-right operation here: the logits drop their last position and the labels drop their first, so the prediction at position i is scored against the token at position i+1.
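A toy illustration of that shift with invented token IDs:

import numpy as np

ids = np.array([101, 8, 9, 10, 102])  # [CLS] a b c [SEP]
print(ids[:-1])  # positions where predictions are made: [101   8   9  10]
print(ids[1:])   # the next-token targets for those positions: [  8   9  10 102]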

from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModel

class GPT2ForSummarization(GPT2LMHeadModel):
    def construct(
        self,
        input_ids = None,
        attention_mask = None,
        labels = None,
    ):
        outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)
        shift_logits = outputs.logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        # Flatten the tokens and ignore the padding positions in the loss
        loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]),
                                 shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)
        return loss

2. Dynamic learning rate

from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

class LinearWithWarmUp(LearningRateSchedule):
    """
    Warmup-decay learning rate.
    """
    def __init__(self, learning_rate, num_warmup_steps, num_training_steps):
        super().__init__()
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps

    def construct(self, global_step):
        if global_step < self.num_warmup_steps:
            return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate
        return ops.maximum(
            0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))
        ) * self.learning_rate
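As a quick sanity check (hypothetical small step counts), the schedule should ramp linearly to the peak over the warmup steps and then decay linearly to zero:

import mindspore
from mindspore import Tensor

sched = LinearWithWarmUp(learning_rate=1.5e-4, num_warmup_steps=10, num_training_steps=100)
for step in (0, 5, 10, 55, 100):
    # expect 0.0, 7.5e-05, 1.5e-04, 7.5e-05, 0.0
    print(step, sched(Tensor(step, mindspore.int32)))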

IV. Model Training

num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4

num_training_steps = num_epochs * train_dataset.get_dataset_size()

from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)

lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)
# print the number of model parameters
print('number of model parameters: {}'.format(model.num_parameters()))

Output:

number of model parameters: 102068736
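That figure is consistent with the GPT2Config defaults (n_layer=12, n_embd=768, n_positions=1024) and our vocab_size=21128; a back-of-the-envelope sketch, assuming the LM head is weight-tied to the token embedding:

d, L, V, P = 768, 12, 21128, 1024
embeddings = V * d + P * d                    # token + position embeddings
attn = d * 3 * d + 3 * d + d * d + d          # qkv projection + output projection
mlp = d * 4 * d + 4 * d + 4 * d * d + d       # the two feed-forward layers
per_block = attn + mlp + 4 * d                # plus two layernorms per block
total = embeddings + L * per_block + 2 * d    # plus the final layernorm
print(total)  # 102068736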

from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallback

ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',
                                epochs=1, keep_checkpoint_max=2)

trainer = Trainer(network=model, train_dataset=train_dataset,
                  epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # enable mixed precision

Output:

Trainer will use 'StaticLossScaler' with `scale_value=2 ** 10` when `loss_scaler` is None.

Note: higher-spec compute is recommended, since training takes a long time.

Memory usage is heavy, so run it when resources are idle; the estimated time was a full 8 hours. (With 45,000 training samples at batch size 4, one epoch is 11,250 steps, which matches the progress bar below.)

trainer.run(tgt_columns="labels")

Output:

The train will start from the checkpoint saved in 'checkpoint'.
Epoch 0:   7% 822/11250 [40:56<8:12:17,  2.83s/it, loss=6.386638]
[ERROR] CORE(232,ffff88a68930,python):2024-07-02-03:41:10.735.886 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_232/729418715.py]
(the same GetRealPath error was printed four times)

V. Model Inference

Process the data and convert the token-ID vectors back into Chinese text

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def pad(article):
        # leave room after the article for the generated summary
        tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)
        return tokenized['input_ids']

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    dataset = dataset.map(pad, 'article', ['input_ids'])

    dataset = dataset.batch(batch_size)

    return dataset

test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))

# load the fine-tuned checkpoint and generate a summary for one test sample
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)
model.set_train(False)
model.config.eos_token_id = model.config.sep_token_id
i = 0
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():
    output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)
    output_text = tokenizer.decode(output_ids[0].tolist())
    print(output_text)
    i += 1
    if i == 1:
        break
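The decoded string contains the article followed by the generated summary; since BertTokenizer.decode joins Chinese characters with spaces, a hedged way to pull out only the summary is to take the text between the first and second [SEP] markers and strip the spaces:

summary_only = output_text.split('[SEP]')[1].replace(' ', '')
print(summary_only)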
