Colossal-AI Reproduction Workflow

1 Environment Setup

1.1 CUDA Environment

root@LAPTOP-3SUHS40U:/home/work/ColossalAI# lsb_release -a
LSB Version:    core-11.1.0ubuntu2-noarch:security-11.1.0ubuntu2-noarch
Distributor ID: Ubuntu
Description:    Ubuntu 20.04.6 LTS
Release:        20.04
Codename:       focal
root@LAPTOP-3SUHS40U:/home/work/ColossalAI# nvidia-smi
Sun May 21 05:56:52 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.50                 Driver Version: 531.79       CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 2060         On | 00000000:01:00.0  On |                  N/A |
| N/A   47C    P8                7W /  N/A|   5760MiB /  6144MiB |     10%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A        23      G   /Xwayland                                 N/A      |
|    0   N/A  N/A    306013      C   /python3.10                               N/A      |
+---------------------------------------------------------------------------------------+

1.2 Python Environment

root@LAPTOP-3SUHS40U:/home/work/ColossalAI# python --version
Python 3.10.10

1.3 Python Package Environment

root@LAPTOP-3SUHS40U:/home/work/ColossalAI# pip list
Package                       Version          Editable project location
----------------------------- ---------------- ------------------------------------------------
absl-py                       1.4.0
aiofiles                      23.1.0
aiohttp                       3.8.4
aiosignal                     1.3.1
alabaster                     0.7.13
altair                        5.0.0
anyio                         3.6.2
appdirs                       1.4.4
arrow                         1.2.3
astunparse                    1.6.3
async-timeout                 4.0.2
attrs                         23.1.0
Babel                         2.12.1
bcrypt                        4.0.1
benepar                       0.2.0
blessed                       1.20.0
blis                          0.7.9
boltons                       23.0.0
brotlipy                      0.7.0
cachetools                    5.3.0
catalogue                     2.0.8
certifi                       2022.12.7
cffi                          1.15.1
cfgv                          3.3.1
charset-normalizer            2.0.4
click                         8.1.3
cmake                         3.26.3
coati                         1.0.0
colossalai                    0.2.8
conda                         23.3.1
conda-content-trust           0.1.3
conda-package-handling        2.0.2
conda_package_streaming       0.7.0
confection                    0.0.4
contexttimer                  0.3.3
coverage                      7.2.5
cryptography                  39.0.1
cycler                        0.11.0
cymem                         2.0.7
Cython                        0.29.34
dataclasses-json              0.5.7
datasets                      2.12.0
diffusers                     0.16.1
dill                          0.3.6
distlib                       0.3.6
distro                        1.8.0
dnspython                     2.3.0
docker                        6.0.1
docker-pycreds                0.4.0
docstring-parser              0.8.1
docutils                      0.18.1
einops                        0.6.1
et-xmlfile                    1.1.0
exceptiongroup                1.1.1
expecttest                    0.1.4
fabric                        3.0.1
fastapi                       0.95.1
fbgemm-gpu                    0.4.1
ffmpy                         0.3.0
filelock                      3.12.0
flash-attn                    0.1
fonttools                     4.39.3
frozenlist                    1.3.3
fsspec                        2023.4.0
gitdb                         4.0.10
GitPython                     3.1.31
google-auth                   2.17.3
google-auth-oauthlib          0.4.6
gpustat                       1.1
gradio                        3.30.0
gradio_client                 0.2.4
greenlet                      2.0.2
grpcio                        1.54.0
h11                           0.14.0
httpcore                      0.17.0
httpx                         0.24.0
huggingface-hub               0.14.1
hypothesis                    6.75.1
identify                      2.5.23
idna                          3.4
imagesize                     1.4.1
importlib-metadata            6.6.0
iniconfig                     2.0.0
invoke                        2.1.1
iopath                        0.1.10
Jinja2                        3.1.2
joblib                        1.2.0
jsonpatch                     1.32
jsonpointer                   2.1
jsonschema                    4.17.3
kiwisolver                    1.4.4
langchain                     0.0.161
langcodes                     3.3.0
libcst                        0.4.9
linkify-it-py                 2.0.2
loralib                       0.1.1
Markdown                      3.4.3
markdown-it-py                2.2.0
MarkupSafe                    2.1.2
marshmallow                   3.19.0
marshmallow-enum              1.5.1
matplotlib                    3.5.3
mdit-py-plugins               0.3.3
mdurl                         0.1.2
moreorless                    0.4.0
mpmath                        1.3.0
multidict                     6.0.4
multiprocess                  0.70.14
murmurhash                    1.0.9
mypy-extensions               1.0.0
myst-parser                   0.18.1
networkx                      3.1
ninja                         1.11.1
nltk                          3.8.1
nodeenv                       1.7.0
numexpr                       2.8.4
numpy                         1.24.3
nvidia-cublas-cu11            11.10.3.66
nvidia-cuda-nvrtc-cu11        11.7.99
nvidia-cuda-runtime-cu11      11.7.99
nvidia-cudnn-cu11             8.5.0.96
nvidia-ml-py                  11.525.112
oauthlib                      3.2.2
openai                        0.24.0
openapi-schema-pydantic       1.2.4
openpyxl                      3.1.2
orjson                        3.8.12
packaging                     23.0
pandas                        2.0.1
pandas-stubs                  2.0.1.230501
paramiko                      3.1.0
pathspec                      0.11.1
pathtools                     0.1.2
pathy                         0.10.1
Pillow                        9.5.0
pip                           23.0.1
platformdirs                  3.5.0
plotly                        5.14.1
pluggy                        1.0.0
portalocker                   2.7.0
pre-commit                    3.3.1
preshed                       3.0.8
protobuf                      3.19.6
psutil                        5.9.5
pyarrow                       12.0.0
pyasn1                        0.5.0
pyasn1-modules                0.3.0
pycosat                       0.6.4
pycparser                     2.21
pydantic                      1.10.7
pyDeprecate                   0.3.2
pydub                         0.25.1
Pygments                      2.15.1
PyNaCl                        1.5.0
pyOpenSSL                     23.0.0
pyparsing                     3.0.9
pyre-extensions               0.0.27
pyrsistent                    0.19.3
PySocks                       1.7.1
pytest                        7.3.1
pytest-cov                    4.0.0
python-dateutil               2.8.2
python-etcd                   0.4.5
python-multipart              0.0.6
pytorch-sphinx-theme          0.0.24           /home/work/pytorch/docs/src/pytorch-sphinx-theme
pytorch-triton                2.1.0+7d1a95b046
pytz                          2023.3
PyYAML                        6.0
regex                         2023.5.4
requests                      2.27.1
requests-oauthlib             1.3.1
responses                     0.18.0
rich                          13.3.5
rouge-score                   0.1.2
rsa                           4.9
ruamel.yaml                   0.17.21
ruamel.yaml.clib              0.2.6
safetensors                   0.3.1
scikit-build                  0.17.3
semantic-version              2.10.0
sentencepiece                 0.1.99
sentry-sdk                    1.22.1
setproctitle                  1.3.2
setuptools                    65.6.3
six                           1.16.0
smart-open                    6.3.0
smmap                         5.0.0
sniffio                       1.3.0
snowballstemmer               2.2.0
sortedcontainers              2.4.0
spacy                         3.5.2
spacy-legacy                  3.0.12
spacy-loggers                 1.0.4
Sphinx                        5.0.0
sphinx-copybutton             0.5.0
sphinx-panels                 0.4.1
sphinxcontrib-applehelp       1.0.4
sphinxcontrib-devhelp         1.0.2
sphinxcontrib-htmlhelp        2.0.1
sphinxcontrib-jsmath          1.0.1
sphinxcontrib-katex           0.8.6
sphinxcontrib-qthelp          1.0.3
sphinxcontrib-serializinghtml 1.1.5
SQLAlchemy                    2.0.12
srsly                         2.4.6
sse-starlette                 1.5.0
starlette                     0.26.1
stdlibs                       2022.10.9
sympy                         1.11.1
tabulate                      0.9.0
tenacity                      8.2.2
tensorboard                   2.10.0
tensorboard-data-server       0.6.1
tensorboard-plugin-wit        1.8.1
thinc                         8.1.10
timm                          0.6.13
titans                        0.0.7
tokenizers                    0.13.3
toml                          0.10.2
tomli                         2.0.1
toolz                         0.12.0
torch                         1.13.1           /root/miniconda3/lib/python3.10/site-packages
torch-struct                  0.5
torchaudio                    0.13.1
torchmetrics                  0.11.4
torchrec                      0.2.0
torchvision                   0.14.1
torchx-nightly                2023.5.3
tqdm                          4.65.0
trailrunner                   1.4.0
transformers                  4.28.0.dev0
typer                         0.7.0
types-dataclasses             0.6.6
types-pytz                    2023.3.0.0
typing_extensions             4.5.0
typing-inspect                0.8.0
tzdata                        2023.3
uc-micro-py                   1.0.2
urllib3                       1.26.15
usort                         1.0.6
uvicorn                       0.22.0
virtualenv                    20.23.0
wandb                         0.15.2
wasabi                        1.1.1
wcwidth                       0.2.6
websocket-client              1.5.1
websockets                    11.0.3
Werkzeug                      2.3.3
wheel                         0.38.4
xxhash                        3.2.0
yarl                          1.9.2
zipp                          3.15.0
zstandard                     0.19.0

2 Download the Code

git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
CUDA_EXT=1 pip install .
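
A quick sanity check after the install confirms that the package is importable and that PyTorch can see the GPU (a minimal sketch; the exact version string depends on the commit you checked out):

# Post-install sanity check (assumes the environment from section 1).
import importlib.metadata

import torch

print(importlib.metadata.version("colossalai"))  # e.g. 0.2.8, matching the pip list above
print(torch.cuda.is_available())                 # True when the driver and CUDA runtime are visible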

3 Model Training

3.1 SFT (Supervised Fine-Tuning)

3.1.1 Command

torchrun --standalone --nproc_per_node=1 train_sft.py \
    --pretrain "/mnt/f/kangpengtao/study/ColossalAI/bigscience/bloom-560m/" \
    --model 'bloom' \
    --strategy naive \
    --log_interval 10 \
    --save_path /mnt/f/kangpengtao/study/ColossalAI/Coati-7B \
    --dataset /mnt/f/kangpengtao/study/ColossalAI/InstructionWild/data/instinwild.json \
    --batch_size 1 \
    --accumulation_steps 8 \
    --lr 2e-5 \
    --max_datasets_size 16384 \
    --max_epochs 1 \
    --lora_rank 16
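
The file passed to --dataset is expected to be an Alpaca-style JSON list of instruction records; the field names below are an assumption based on the InstructionWild format, so adjust them to what your copy of instinwild.json actually contains. A quick inspection before training:

import json

# Hedged sketch: peek at the SFT dataset before launching training.
path = "/mnt/f/kangpengtao/study/ColossalAI/InstructionWild/data/instinwild.json"
with open(path, encoding="utf-8") as f:
    records = json.load(f)

print(len(records))       # the log below reports "Loaded 103695 examples."
print(records[0].keys())  # expected keys include 'instruction' and 'output' (assumed schema)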

3.1.2 Log

root@LAPTOP-3SUHS40U:/home/work/ColossalAI/applications/Chat/examples# ./train_sft_bloom_kpt.sh
[05/16/23 15:54:24] INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:121 __init__
                    INFO     colossalai - colossalai - INFO: Loading data...
[05/16/23 15:54:25] INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:123 __init__
                    INFO     colossalai - colossalai - INFO: Loaded 103695 examples.
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:126 __init__
                    INFO     colossalai - colossalai - INFO: Limiting dataset to 16384 examples.
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:129 __init__
                    INFO     colossalai - colossalai - INFO: Formatting inputs...
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:137 __init__
                    INFO     colossalai - colossalai - INFO: Tokenizing inputs... This may take some time...
steps:   0%|                                                                                                                       | 0/2048 [00:00<?, ?it/s][05/16/23 15:54:40] WARNING  colossalai - colossalai - WARNING: /root/miniconda3/lib/python3.10/site-packages/coati/trainer/sft.py:86 fit
                    WARNING  colossalai - colossalai - WARNING: batch_id:7, abnormal loss: 2.74609375
steps: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2048/2048 [36:00<00:00,  1.05s/it]
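
With --lora_rank 16, training only updates low-rank adapter matrices; these are the lora_A/lora_B keys that show up in the checkpoint warnings of the next stage. A conceptual sketch of the LoRA forward pass (generic LoRA, not coati's exact implementation):

import torch

# Conceptual LoRA sketch: the frozen weight W is augmented by a rank-r update B @ A.
d_out, d_in, r = 1024, 1024, 16          # r corresponds to --lora_rank 16
W = torch.randn(d_out, d_in)             # frozen pretrained weight
A = torch.randn(r, d_in) * 0.01          # trainable down-projection (lora_A)
B = torch.zeros(d_out, r)                # trainable up-projection (lora_B), zero-initialized
x = torch.randn(d_in)
y = W @ x + B @ (A @ x)                  # output starts identical to the base model
print(y.shape)                           # torch.Size([1024])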

3.2 Training the Reward Model

3.2.1 Command

set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 1

torchrun --standalone --nproc_per_node=1 train_reward_model.py \
   --pretrain  '/mnt/f/kangpengtao/study/ColossalAI/Coati-7B' \
   --model 'bloom' \
   --strategy naive \
   --loss_fn 'log_sig' \
   --dataset 'Anthropic/hh-rlhf' \
   --save_path '/mnt/f/kangpengtao/study/ColossalAI/rm-static.pt' \
   --lora_rank 16 \
   --batch_size 1 \
   --max_len 512
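
--loss_fn 'log_sig' selects the standard pairwise ranking objective: for each hh-rlhf pair the model should score the chosen response above the rejected one. A minimal sketch of that loss (the standard formulation, not necessarily coati's exact code):

import torch
import torch.nn.functional as F

# Pairwise log-sigmoid ranking loss: -log(sigmoid(r_chosen - r_rejected)).
def log_sig_loss(chosen_reward: torch.Tensor, rejected_reward: torch.Tensor) -> torch.Tensor:
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()

# Toy batch of two preference pairs; the second pair is ranked the wrong way round.
chosen = torch.tensor([1.2, 0.3])
rejected = torch.tensor([0.4, 0.9])
print(log_sig_loss(chosen, rejected))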

3.2.2 Log

root@LAPTOP-3SUHS40U:/home/work/ColossalAI/applications/Chat/examples# ./train_rm_bloom_kpt.sh
GPU Memory Usage:
     0  228 MiB
Now CUDA_VISIBLE_DEVICES is set to:
CUDA_VISIBLE_DEVICES=0
Some weights of the model checkpoint at /mnt/f/kangpengtao/study/ColossalAI/Coati-7B were not used when initializing BloomModel: ['transformer.h.23.mlp.dense_4h_to_h.lora_B', 'transformer.h.11.mlp.dense_4h_to_h.lora_B', 'transformer.h.18.self_attention.dense.lora_B', 'transformer.h.6.mlp.dense_4h_to_h.lora_B', 'transformer.h.11.mlp.dense_4h_to_h.lora_A', 'transformer.h.11.self_attention.query_key_value.lora_B', 'transformer.h.21.self_attention.query_key_value.lora_B', 'transformer.h.22.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.mlp.dense_h_to_4h.lora_A', 'lm_head.lora_A', 'transformer.h.22.self_attention.dense.lora_A', 'transformer.h.6.self_attention.query_key_value.lora_A', 'transformer.h.14.mlp.dense_h_to_4h.lora_A', 'transformer.h.5.mlp.dense_h_to_4h.lora_B', 'transformer.h.9.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.self_attention.query_key_value.lora_A', 'transformer.h.6.mlp.dense_h_to_4h.lora_B', 'transformer.h.20.mlp.dense_4h_to_h.lora_A', 'transformer.h.19.self_attention.dense.lora_A', 'transformer.h.17.self_attention.query_key_value.lora_B', 'transformer.h.0.mlp.dense_4h_to_h.lora_B', 'transformer.h.13.self_attention.query_key_value.lora_A', 'transformer.h.23.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.self_attention.query_key_value.lora_B', 'transformer.h.9.self_attention.dense.lora_B', 'transformer.h.23.mlp.dense_4h_to_h.lora_A', 'transformer.h.9.self_attention.query_key_value.lora_B', 'transformer.h.0.mlp.dense_h_to_4h.lora_A', 'transformer.h.19.mlp.dense_4h_to_h.lora_A', 'transformer.h.4.self_attention.dense.lora_B', 'transformer.h.16.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.self_attention.dense.lora_B', 'transformer.h.4.mlp.dense_4h_to_h.lora_B', 'transformer.h.16.self_attention.query_key_value.lora_B', 'transformer.h.5.mlp.dense_4h_to_h.lora_B', 'transformer.h.10.mlp.dense_4h_to_h.lora_A', 'transformer.h.18.self_attention.query_key_value.lora_A', 'transformer.h.12.self_attention.query_key_value.lora_A', 'transformer.h.22.self_attention.dense.lora_B', 'transformer.h.1.mlp.dense_4h_to_h.lora_A', 'transformer.h.6.self_attention.dense.lora_A', 'transformer.h.13.mlp.dense_h_to_4h.lora_B', 'transformer.h.2.self_attention.dense.lora_A', 'transformer.h.1.self_attention.query_key_value.lora_A', 'transformer.h.12.mlp.dense_h_to_4h.lora_A', 'transformer.h.4.self_attention.query_key_value.lora_B', 'transformer.h.13.self_attention.query_key_value.lora_B', 'lm_head.lora_B', 'transformer.h.17.mlp.dense_4h_to_h.lora_B', 'transformer.h.12.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.self_attention.query_key_value.lora_A', 'transformer.h.21.mlp.dense_4h_to_h.lora_B', 'transformer.h.20.mlp.dense_h_to_4h.lora_B', 'transformer.h.15.self_attention.dense.lora_B', 'transformer.h.11.self_attention.dense.lora_B', 'transformer.h.2.self_attention.query_key_value.lora_B', 'transformer.h.16.mlp.dense_h_to_4h.lora_A', 'transformer.h.5.self_attention.query_key_value.lora_A', 'transformer.h.16.self_attention.query_key_value.lora_A', 'transformer.h.23.self_attention.dense.lora_B', 'transformer.h.13.self_attention.dense.lora_B', 'transformer.h.6.self_attention.dense.lora_B', 'transformer.h.13.self_attention.dense.lora_A', 'transformer.h.1.self_attention.query_key_value.lora_B', 'transformer.h.22.self_attention.query_key_value.lora_A', 'transformer.h.15.mlp.dense_4h_to_h.lora_B', 'transformer.h.12.self_attention.dense.lora_B', 'transformer.h.2.self_attention.dense.lora_B', 'transformer.h.15.mlp.dense_4h_to_h.lora_A', 'transformer.h.1.mlp.dense_h_to_4h.lora_B', 
'transformer.h.11.self_attention.dense.lora_A', 'transformer.h.19.self_attention.query_key_value.lora_A', 'transformer.h.4.self_attention.dense.lora_A', 'transformer.h.14.self_attention.query_key_value.lora_B', 'transformer.h.9.self_attention.dense.lora_A', 'transformer.h.22.mlp.dense_h_to_4h.lora_B', 'transformer.h.15.self_attention.dense.lora_A', 'transformer.h.0.mlp.dense_4h_to_h.lora_A', 'transformer.h.3.self_attention.query_key_value.lora_B', 'transformer.h.17.mlp.dense_4h_to_h.lora_A', 'transformer.h.22.self_attention.query_key_value.lora_B', 'transformer.h.7.self_attention.dense.lora_B', 'transformer.h.5.mlp.dense_4h_to_h.lora_A', 'transformer.h.10.self_attention.query_key_value.lora_A', 'transformer.h.22.mlp.dense_4h_to_h.lora_B', 'transformer.h.7.self_attention.query_key_value.lora_A', 'transformer.h.2.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.self_attention.dense.lora_A', 'transformer.h.15.mlp.dense_h_to_4h.lora_B', 'transformer.h.11.mlp.dense_h_to_4h.lora_A', 'transformer.h.0.self_attention.dense.lora_A', 'transformer.h.3.mlp.dense_h_to_4h.lora_A', 'transformer.h.19.mlp.dense_h_to_4h.lora_B', 'transformer.h.5.mlp.dense_h_to_4h.lora_A', 'transformer.h.3.self_attention.dense.lora_A', 'transformer.h.10.mlp.dense_4h_to_h.lora_B', 'transformer.h.9.self_attention.query_key_value.lora_A', 'transformer.h.8.self_attention.dense.lora_B', 'transformer.h.12.mlp.dense_4h_to_h.lora_A', 'transformer.h.19.mlp.dense_4h_to_h.lora_B', 'transformer.h.6.self_attention.query_key_value.lora_B', 'transformer.h.9.mlp.dense_4h_to_h.lora_B', 'transformer.h.13.mlp.dense_4h_to_h.lora_B', 'transformer.h.12.self_attention.query_key_value.lora_B', 'transformer.h.16.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.mlp.dense_h_to_4h.lora_A', 'transformer.h.5.self_attention.dense.lora_B', 'transformer.h.17.self_attention.query_key_value.lora_A', 'transformer.h.9.mlp.dense_4h_to_h.lora_A', 'transformer.h.10.self_attention.dense.lora_A', 'transformer.h.1.mlp.dense_h_to_4h.lora_A', 'transformer.h.21.self_attention.query_key_value.lora_A', 'transformer.h.10.mlp.dense_h_to_4h.lora_A', 'transformer.h.15.mlp.dense_h_to_4h.lora_A', 'transformer.h.8.mlp.dense_h_to_4h.lora_B', 'transformer.h.21.mlp.dense_h_to_4h.lora_A', 'transformer.h.7.self_attention.dense.lora_A', 'transformer.h.16.self_attention.dense.lora_B', 'transformer.h.17.self_attention.dense.lora_A', 'transformer.h.20.mlp.dense_4h_to_h.lora_B', 'transformer.h.15.self_attention.query_key_value.lora_B', 'transformer.h.22.mlp.dense_4h_to_h.lora_A', 'transformer.h.18.self_attention.query_key_value.lora_B', 'transformer.h.13.mlp.dense_h_to_4h.lora_A', 'transformer.h.4.mlp.dense_4h_to_h.lora_A', 'transformer.h.1.mlp.dense_4h_to_h.lora_B', 'transformer.h.15.self_attention.query_key_value.lora_A', 'transformer.h.11.self_attention.query_key_value.lora_A', 'transformer.h.3.self_attention.dense.lora_B', 'transformer.h.4.self_attention.query_key_value.lora_A', 'transformer.h.0.self_attention.dense.lora_B', 'transformer.h.13.mlp.dense_4h_to_h.lora_A', 'transformer.h.3.mlp.dense_4h_to_h.lora_B', 'transformer.h.5.self_attention.dense.lora_A', 'transformer.h.10.self_attention.dense.lora_B', 'transformer.h.23.mlp.dense_h_to_4h.lora_A', 'transformer.h.3.self_attention.query_key_value.lora_A', 'transformer.h.21.self_attention.dense.lora_B', 'transformer.h.17.mlp.dense_h_to_4h.lora_B', 'transformer.h.2.mlp.dense_4h_to_h.lora_B', 'transformer.h.8.self_attention.query_key_value.lora_A', 'transformer.h.11.mlp.dense_h_to_4h.lora_B', 'transformer.h.3.mlp.dense_4h_to_h.lora_A', 
'transformer.h.18.mlp.dense_4h_to_h.lora_B', 'transformer.h.16.mlp.dense_4h_to_h.lora_B', 'transformer.h.21.mlp.dense_h_to_4h.lora_B', 'transformer.h.2.mlp.dense_4h_to_h.lora_A', 'transformer.h.7.mlp.dense_h_to_4h.lora_B', 'transformer.h.18.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.mlp.dense_4h_to_h.lora_A', 'transformer.h.4.mlp.dense_h_to_4h.lora_B', 'transformer.h.7.self_attention.query_key_value.lora_B', 'transformer.h.9.mlp.dense_h_to_4h.lora_B', 'transformer.h.21.self_attention.dense.lora_A', 'transformer.h.20.self_attention.dense.lora_B', 'transformer.h.8.self_attention.query_key_value.lora_B', 'transformer.h.8.self_attention.dense.lora_A', 'transformer.h.2.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.mlp.dense_4h_to_h.lora_B', 'transformer.h.6.mlp.dense_4h_to_h.lora_A', 'transformer.h.7.mlp.dense_4h_to_h.lora_B', 'transformer.h.18.mlp.dense_h_to_4h.lora_B', 'transformer.h.17.self_attention.dense.lora_B', 'transformer.h.0.self_attention.query_key_value.lora_A', 'transformer.h.7.mlp.dense_4h_to_h.lora_A', 'transformer.h.12.mlp.dense_4h_to_h.lora_B', 'transformer.h.10.self_attention.query_key_value.lora_B', 'transformer.h.19.self_attention.query_key_value.lora_B', 'transformer.h.21.mlp.dense_4h_to_h.lora_A', 'transformer.h.14.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.self_attention.dense.lora_A', 'transformer.h.4.mlp.dense_h_to_4h.lora_A', 'transformer.h.18.self_attention.dense.lora_A', 'transformer.h.7.mlp.dense_h_to_4h.lora_A', 'transformer.h.19.self_attention.dense.lora_B', 'transformer.h.19.mlp.dense_h_to_4h.lora_A', 'transformer.h.17.mlp.dense_h_to_4h.lora_A', 'transformer.h.18.mlp.dense_h_to_4h.lora_A', 'transformer.h.0.self_attention.query_key_value.lora_B', 'transformer.h.14.mlp.dense_4h_to_h.lora_B', 'transformer.h.16.self_attention.dense.lora_A', 'transformer.h.10.mlp.dense_h_to_4h.lora_B', 'lm_head.weight', 'transformer.h.1.self_attention.dense.lora_B', 'transformer.h.3.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.self_attention.dense.lora_A', 'transformer.h.5.self_attention.query_key_value.lora_B', 'transformer.h.0.mlp.dense_h_to_4h.lora_B', 'transformer.h.2.self_attention.query_key_value.lora_A', 'transformer.h.14.self_attention.query_key_value.lora_A', 'transformer.h.1.self_attention.dense.lora_A', 'transformer.h.12.self_attention.dense.lora_A', 'transformer.h.6.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.self_attention.query_key_value.lora_B']
- This IS expected if you are initializing BloomModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BloomModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Found cached dataset json (/root/.cache/huggingface/datasets/json/hh-rlhf-226e0526113c616f/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 383.92it/s]
Parameter 'indices'=<generator object train.<locals>.<genexpr> at 0x7fdf71767ca0> of the transform datasets.arrow_dataset.Dataset.select couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 160800/160800 [05:40<00:00, 472.16it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1710/1710 [00:04<00:00, 422.55it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8552/8552 [00:19<00:00, 433.97it/s]
Train step of epoch 0: 100%|███████████████████████████████████████████████████████████████████| 160800/160800 [100:45:04<00:00,  2.26s/it, dist=nan, acc=0]
Train epoch: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [100:45:04<00:00, 362704.33s/it]
root@LAPTOP-3SUHS40U:/home/work/ColossalAI/applications/Chat/examples#
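
Once rm-static.pt is written, the reward model can be loaded back for a quick spot check. The sketch below assumes coati exposes a BLOOMRM class next to BLOOMActor and that its forward accepts token ids directly; treat both as assumptions and adapt to the installed coati version:

import torch
from coati.models.bloom import BLOOMRM        # assumption: the RM counterpart of BLOOMActor
from transformers import AutoTokenizer

# Hedged sketch: score one prompt/response pair with the trained reward model.
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
rm = BLOOMRM(pretrained='bigscience/bloom-560m').to(torch.cuda.current_device())
rm.load_state_dict(torch.load('/mnt/f/kangpengtao/study/ColossalAI/rm-static.pt'), strict=False)
rm.eval()

text = 'Question: how do I sort a list in Python? Answer: use sorted(my_list).'
ids = tokenizer(text, return_tensors='pt').input_ids.to(torch.cuda.current_device())
with torch.no_grad():
    print(rm(ids))                            # expected: one scalar reward per sequence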

3.3 RL (Training the Model on Prompts with RL)

3.3.1 Command

set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 1

# torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2

# torchrun --standalone --nproc_per_node=2 train_prompts.py --prompt_dataset /path/to/data.json --strategy colossalai_zero2

torchrun --standalone --nproc_per_node=1 train_prompts.py \
        --prompt_dataset '/mnt/f/kangpengtao/study/ColossalAI/InstructionWild/data/prompts.json' \
        --pretrain_dataset '/mnt/f/kangpengtao/study/ColossalAI/InstructionWild/data/instinwild.json' \
        --strategy naive \
        --model bloom \
        --pretrain '/mnt/f/kangpengtao/study/ColossalAI/bigscience/bloom-560m' \
        --rm_path '/mnt/f/kangpengtao/study/ColossalAI/rm-static.pt' \
        --rm_pretrain '/mnt/f/kangpengtao/study/ColossalAI/Coati-7B' \
        --save_path '/mnt/f/kangpengtao/study/ColossalAI/prompts-static.pt' \
        --max_epochs 1 \
        --num_episodes 1 \
        --train_batch_size 1 \
        --ptx_batch_size 1 \
        --experience_batch_size 1 \
        --lora_rank 16
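
Here --prompt_dataset supplies the prompts used to generate experience, while --pretrain_dataset feeds the auxiliary ptx (pretraining) loss (batched via --ptx_batch_size) that keeps the actor close to its SFT behaviour. The policy update itself is PPO; a minimal sketch of the clipped policy loss (standard PPO, not coati's exact implementation):

import torch

# Clipped PPO policy loss: maximize min(ratio * A, clip(ratio, 1-eps, 1+eps) * A).
def ppo_policy_loss(log_probs, old_log_probs, advantages, clip_eps=0.2):
    ratio = (log_probs - old_log_probs).exp()
    unclipped = ratio * advantages
    clipped = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()

# Toy example with two sampled tokens.
lp = torch.tensor([-1.0, -0.5])
old_lp = torch.tensor([-1.2, -0.4])
adv = torch.tensor([0.8, -0.3])
print(ppo_policy_loss(lp, old_lp, adv))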

3.3.2 Log

root@LAPTOP-3SUHS40U:/home/work/ColossalAI/applications/Chat/examples# ./train_prompts_bloom_kpt.sh
GPU Memory Usage:
     0  580 MiB
Now CUDA_VISIBLE_DEVICES is set to:
CUDA_VISIBLE_DEVICES=0
Some weights of the model checkpoint at /mnt/f/kangpengtao/study/ColossalAI/Coati-7B were not used when initializing BloomModel: ['transformer.h.5.self_attention.query_key_value.lora_A', 'transformer.h.7.self_attention.query_key_value.lora_A', 'transformer.h.15.self_attention.query_key_value.lora_B', 'transformer.h.6.self_attention.query_key_value.lora_A', 'transformer.h.16.mlp.dense_h_to_4h.lora_A', 'transformer.h.2.mlp.dense_h_to_4h.lora_A', 'transformer.h.22.self_attention.query_key_value.lora_B', 'transformer.h.14.mlp.dense_4h_to_h.lora_A', 'transformer.h.12.self_attention.dense.lora_A', 'transformer.h.15.mlp.dense_h_to_4h.lora_A', 'transformer.h.3.self_attention.dense.lora_A', 'transformer.h.10.self_attention.query_key_value.lora_B', 'transformer.h.1.self_attention.dense.lora_B', 'transformer.h.2.self_attention.dense.lora_A', 'transformer.h.21.mlp.dense_h_to_4h.lora_A', 'transformer.h.12.self_attention.query_key_value.lora_A', 'transformer.h.9.self_attention.query_key_value.lora_B', 'transformer.h.13.mlp.dense_4h_to_h.lora_B', 'transformer.h.3.self_attention.query_key_value.lora_A', 'transformer.h.11.mlp.dense_4h_to_h.lora_B', 'transformer.h.7.self_attention.dense.lora_A', 'transformer.h.3.self_attention.dense.lora_B', 'transformer.h.21.self_attention.dense.lora_A', 'transformer.h.7.mlp.dense_h_to_4h.lora_A', 'transformer.h.12.mlp.dense_4h_to_h.lora_B', 'transformer.h.13.self_attention.query_key_value.lora_A', 'transformer.h.19.mlp.dense_4h_to_h.lora_A', 'transformer.h.0.self_attention.dense.lora_A', 'transformer.h.13.mlp.dense_h_to_4h.lora_A', 'transformer.h.16.mlp.dense_4h_to_h.lora_A', 'transformer.h.14.mlp.dense_h_to_4h.lora_B', 'transformer.h.17.self_attention.dense.lora_A', 'transformer.h.3.mlp.dense_h_to_4h.lora_A', 'transformer.h.15.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.mlp.dense_h_to_4h.lora_A', 'transformer.h.4.mlp.dense_4h_to_h.lora_B', 'transformer.h.2.self_attention.query_key_value.lora_A', 'transformer.h.18.mlp.dense_4h_to_h.lora_A', 'transformer.h.18.self_attention.dense.lora_A', 'transformer.h.21.mlp.dense_4h_to_h.lora_B', 'transformer.h.1.mlp.dense_h_to_4h.lora_A', 'transformer.h.18.self_attention.dense.lora_B', 'transformer.h.17.self_attention.query_key_value.lora_A', 'transformer.h.11.mlp.dense_4h_to_h.lora_A', 'transformer.h.0.self_attention.dense.lora_B', 'transformer.h.11.self_attention.query_key_value.lora_A', 'transformer.h.5.self_attention.query_key_value.lora_B', 'transformer.h.20.mlp.dense_4h_to_h.lora_A', 'transformer.h.16.mlp.dense_4h_to_h.lora_B', 'transformer.h.9.mlp.dense_h_to_4h.lora_B', 'transformer.h.17.mlp.dense_h_to_4h.lora_A', 'transformer.h.21.self_attention.dense.lora_B', 'transformer.h.18.self_attention.query_key_value.lora_B', 'transformer.h.2.self_attention.dense.lora_B', 'transformer.h.5.mlp.dense_h_to_4h.lora_B', 'transformer.h.10.mlp.dense_4h_to_h.lora_B', 'transformer.h.13.self_attention.dense.lora_A', 'transformer.h.3.mlp.dense_4h_to_h.lora_B', 'transformer.h.19.self_attention.query_key_value.lora_A', 'transformer.h.19.self_attention.dense.lora_A', 'transformer.h.13.mlp.dense_h_to_4h.lora_B', 'transformer.h.15.mlp.dense_4h_to_h.lora_B', 'transformer.h.9.self_attention.dense.lora_A', 'lm_head.lora_A', 'transformer.h.5.mlp.dense_4h_to_h.lora_A', 'transformer.h.16.self_attention.query_key_value.lora_B', 'transformer.h.7.self_attention.dense.lora_B', 'transformer.h.18.mlp.dense_h_to_4h.lora_A', 'transformer.h.11.self_attention.query_key_value.lora_B', 'transformer.h.1.mlp.dense_4h_to_h.lora_A', 
'transformer.h.10.mlp.dense_4h_to_h.lora_A', 'transformer.h.21.mlp.dense_h_to_4h.lora_B', 'transformer.h.17.self_attention.query_key_value.lora_B', 'transformer.h.19.mlp.dense_h_to_4h.lora_A', 'transformer.h.6.self_attention.query_key_value.lora_B', 'transformer.h.12.self_attention.query_key_value.lora_B', 'transformer.h.9.self_attention.dense.lora_B', 'transformer.h.22.mlp.dense_4h_to_h.lora_A', 'transformer.h.19.mlp.dense_4h_to_h.lora_B', 'transformer.h.9.mlp.dense_4h_to_h.lora_A', 'transformer.h.0.mlp.dense_4h_to_h.lora_B', 'transformer.h.6.mlp.dense_h_to_4h.lora_A', 'transformer.h.10.self_attention.query_key_value.lora_A', 'transformer.h.16.self_attention.dense.lora_A', 'transformer.h.4.mlp.dense_h_to_4h.lora_A', 'transformer.h.0.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.mlp.dense_4h_to_h.lora_A', 'transformer.h.1.mlp.dense_h_to_4h.lora_B', 'transformer.h.6.self_attention.dense.lora_B', 'transformer.h.15.self_attention.query_key_value.lora_A', 'transformer.h.0.self_attention.query_key_value.lora_A', 'transformer.h.0.mlp.dense_h_to_4h.lora_B', 'transformer.h.8.mlp.dense_4h_to_h.lora_B', 'transformer.h.10.mlp.dense_h_to_4h.lora_A', 'transformer.h.6.self_attention.dense.lora_A', 'transformer.h.11.self_attention.dense.lora_B', 'transformer.h.17.mlp.dense_4h_to_h.lora_A', 'transformer.h.23.mlp.dense_h_to_4h.lora_A', 'transformer.h.12.mlp.dense_h_to_4h.lora_B', 'transformer.h.8.mlp.dense_h_to_4h.lora_B', 'transformer.h.16.self_attention.dense.lora_B', 'transformer.h.15.self_attention.dense.lora_B', 'transformer.h.10.mlp.dense_h_to_4h.lora_B', 'transformer.h.20.mlp.dense_h_to_4h.lora_A', 'transformer.h.1.self_attention.dense.lora_A', 'transformer.h.2.mlp.dense_4h_to_h.lora_A', 'transformer.h.12.mlp.dense_4h_to_h.lora_A', 'transformer.h.3.mlp.dense_4h_to_h.lora_A', 'transformer.h.2.self_attention.query_key_value.lora_B', 'transformer.h.12.self_attention.dense.lora_B', 'transformer.h.18.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.mlp.dense_4h_to_h.lora_B', 'transformer.h.17.self_attention.dense.lora_B', 'transformer.h.16.self_attention.query_key_value.lora_A', 'transformer.h.23.self_attention.query_key_value.lora_B', 'transformer.h.7.mlp.dense_4h_to_h.lora_A', 'transformer.h.23.self_attention.dense.lora_B', 'transformer.h.14.mlp.dense_4h_to_h.lora_B', 'transformer.h.7.mlp.dense_h_to_4h.lora_B', 'transformer.h.8.self_attention.dense.lora_B', 'transformer.h.9.mlp.dense_h_to_4h.lora_A', 'transformer.h.14.self_attention.dense.lora_B', 'transformer.h.1.mlp.dense_4h_to_h.lora_B', 'transformer.h.1.self_attention.query_key_value.lora_B', 'transformer.h.8.self_attention.query_key_value.lora_B', 'transformer.h.8.mlp.dense_h_to_4h.lora_A', 'transformer.h.2.mlp.dense_4h_to_h.lora_B', 'transformer.h.21.self_attention.query_key_value.lora_B', 'transformer.h.20.mlp.dense_4h_to_h.lora_B', 'transformer.h.0.self_attention.query_key_value.lora_B', 'transformer.h.5.self_attention.dense.lora_A', 'transformer.h.2.mlp.dense_h_to_4h.lora_B', 'transformer.h.10.self_attention.dense.lora_A', 'transformer.h.4.self_attention.query_key_value.lora_A', 'transformer.h.14.self_attention.query_key_value.lora_B', 'transformer.h.8.self_attention.query_key_value.lora_A', 'transformer.h.18.self_attention.query_key_value.lora_A', 'transformer.h.4.mlp.dense_4h_to_h.lora_A', 'transformer.h.4.self_attention.query_key_value.lora_B', 'transformer.h.4.self_attention.dense.lora_B', 'transformer.h.21.self_attention.query_key_value.lora_A', 'transformer.h.7.self_attention.query_key_value.lora_B', 
'transformer.h.20.self_attention.dense.lora_A', 'transformer.h.18.mlp.dense_4h_to_h.lora_B', 'transformer.h.22.mlp.dense_h_to_4h.lora_B', 'transformer.h.22.mlp.dense_h_to_4h.lora_A', 'transformer.h.9.self_attention.query_key_value.lora_A', 'transformer.h.11.mlp.dense_h_to_4h.lora_B', 'transformer.h.5.self_attention.dense.lora_B', 'transformer.h.3.self_attention.query_key_value.lora_B', 'transformer.h.11.self_attention.dense.lora_A', 'transformer.h.23.self_attention.query_key_value.lora_A', 'transformer.h.23.mlp.dense_h_to_4h.lora_B', 'transformer.h.20.self_attention.dense.lora_B', 'transformer.h.3.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.mlp.dense_4h_to_h.lora_A', 'transformer.h.13.mlp.dense_4h_to_h.lora_A', 'transformer.h.6.mlp.dense_4h_to_h.lora_A', 'transformer.h.6.mlp.dense_4h_to_h.lora_B', 'transformer.h.1.self_attention.query_key_value.lora_A', 'transformer.h.5.mlp.dense_h_to_4h.lora_A', 'transformer.h.13.self_attention.query_key_value.lora_B', 'lm_head.lora_B', 'transformer.h.0.mlp.dense_h_to_4h.lora_A', 'transformer.h.7.mlp.dense_4h_to_h.lora_B', 'transformer.h.22.self_attention.dense.lora_B', 'transformer.h.19.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.self_attention.dense.lora_A', 'transformer.h.16.mlp.dense_h_to_4h.lora_B', 'transformer.h.6.mlp.dense_h_to_4h.lora_B', 'transformer.h.22.self_attention.query_key_value.lora_A', 'lm_head.weight', 'transformer.h.22.self_attention.dense.lora_A', 'transformer.h.15.mlp.dense_4h_to_h.lora_A', 'transformer.h.15.self_attention.dense.lora_A', 'transformer.h.10.self_attention.dense.lora_B', 'transformer.h.12.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.self_attention.query_key_value.lora_B', 'transformer.h.23.self_attention.dense.lora_A', 'transformer.h.22.mlp.dense_4h_to_h.lora_B', 'transformer.h.11.mlp.dense_h_to_4h.lora_A', 'transformer.h.13.self_attention.dense.lora_B', 'transformer.h.20.self_attention.query_key_value.lora_A', 'transformer.h.5.mlp.dense_4h_to_h.lora_B', 'transformer.h.4.mlp.dense_h_to_4h.lora_B', 'transformer.h.19.self_attention.query_key_value.lora_B', 'transformer.h.20.mlp.dense_h_to_4h.lora_B', 'transformer.h.4.self_attention.dense.lora_A', 'transformer.h.19.self_attention.dense.lora_B', 'transformer.h.9.mlp.dense_4h_to_h.lora_B', 'transformer.h.17.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.self_attention.query_key_value.lora_A', 'transformer.h.21.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.self_attention.dense.lora_A', 'transformer.h.17.mlp.dense_4h_to_h.lora_B']
- This IS expected if you are initializing BloomModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BloomModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at /mnt/f/kangpengtao/study/ColossalAI/Coati-7B were not used when initializing BloomModel: ['transformer.h.5.self_attention.query_key_value.lora_A', 'transformer.h.7.self_attention.query_key_value.lora_A', 'transformer.h.15.self_attention.query_key_value.lora_B', 'transformer.h.6.self_attention.query_key_value.lora_A', 'transformer.h.16.mlp.dense_h_to_4h.lora_A', 'transformer.h.2.mlp.dense_h_to_4h.lora_A', 'transformer.h.22.self_attention.query_key_value.lora_B', 'transformer.h.14.mlp.dense_4h_to_h.lora_A', 'transformer.h.12.self_attention.dense.lora_A', 'transformer.h.15.mlp.dense_h_to_4h.lora_A', 'transformer.h.3.self_attention.dense.lora_A', 'transformer.h.10.self_attention.query_key_value.lora_B', 'transformer.h.1.self_attention.dense.lora_B', 'transformer.h.2.self_attention.dense.lora_A', 'transformer.h.21.mlp.dense_h_to_4h.lora_A', 'transformer.h.12.self_attention.query_key_value.lora_A', 'transformer.h.9.self_attention.query_key_value.lora_B', 'transformer.h.13.mlp.dense_4h_to_h.lora_B', 'transformer.h.3.self_attention.query_key_value.lora_A', 'transformer.h.11.mlp.dense_4h_to_h.lora_B', 'transformer.h.7.self_attention.dense.lora_A', 'transformer.h.3.self_attention.dense.lora_B', 'transformer.h.21.self_attention.dense.lora_A', 'transformer.h.7.mlp.dense_h_to_4h.lora_A', 'transformer.h.12.mlp.dense_4h_to_h.lora_B', 'transformer.h.13.self_attention.query_key_value.lora_A', 'transformer.h.19.mlp.dense_4h_to_h.lora_A', 'transformer.h.0.self_attention.dense.lora_A', 'transformer.h.13.mlp.dense_h_to_4h.lora_A', 'transformer.h.16.mlp.dense_4h_to_h.lora_A', 'transformer.h.14.mlp.dense_h_to_4h.lora_B', 'transformer.h.17.self_attention.dense.lora_A', 'transformer.h.3.mlp.dense_h_to_4h.lora_A', 'transformer.h.15.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.mlp.dense_h_to_4h.lora_A', 'transformer.h.4.mlp.dense_4h_to_h.lora_B', 'transformer.h.2.self_attention.query_key_value.lora_A', 'transformer.h.18.mlp.dense_4h_to_h.lora_A', 'transformer.h.18.self_attention.dense.lora_A', 'transformer.h.21.mlp.dense_4h_to_h.lora_B', 'transformer.h.1.mlp.dense_h_to_4h.lora_A', 'transformer.h.18.self_attention.dense.lora_B', 'transformer.h.17.self_attention.query_key_value.lora_A', 'transformer.h.11.mlp.dense_4h_to_h.lora_A', 'transformer.h.0.self_attention.dense.lora_B', 'transformer.h.11.self_attention.query_key_value.lora_A', 'transformer.h.5.self_attention.query_key_value.lora_B', 'transformer.h.20.mlp.dense_4h_to_h.lora_A', 'transformer.h.16.mlp.dense_4h_to_h.lora_B', 'transformer.h.9.mlp.dense_h_to_4h.lora_B', 'transformer.h.17.mlp.dense_h_to_4h.lora_A', 'transformer.h.21.self_attention.dense.lora_B', 'transformer.h.18.self_attention.query_key_value.lora_B', 'transformer.h.2.self_attention.dense.lora_B', 'transformer.h.5.mlp.dense_h_to_4h.lora_B', 'transformer.h.10.mlp.dense_4h_to_h.lora_B', 'transformer.h.13.self_attention.dense.lora_A', 'transformer.h.3.mlp.dense_4h_to_h.lora_B', 'transformer.h.19.self_attention.query_key_value.lora_A', 'transformer.h.19.self_attention.dense.lora_A', 'transformer.h.13.mlp.dense_h_to_4h.lora_B', 'transformer.h.15.mlp.dense_4h_to_h.lora_B', 'transformer.h.9.self_attention.dense.lora_A', 'lm_head.lora_A', 'transformer.h.5.mlp.dense_4h_to_h.lora_A', 'transformer.h.16.self_attention.query_key_value.lora_B', 'transformer.h.7.self_attention.dense.lora_B', 'transformer.h.18.mlp.dense_h_to_4h.lora_A', 'transformer.h.11.self_attention.query_key_value.lora_B', 'transformer.h.1.mlp.dense_4h_to_h.lora_A', 
'transformer.h.10.mlp.dense_4h_to_h.lora_A', 'transformer.h.21.mlp.dense_h_to_4h.lora_B', 'transformer.h.17.self_attention.query_key_value.lora_B', 'transformer.h.19.mlp.dense_h_to_4h.lora_A', 'transformer.h.6.self_attention.query_key_value.lora_B', 'transformer.h.12.self_attention.query_key_value.lora_B', 'transformer.h.9.self_attention.dense.lora_B', 'transformer.h.22.mlp.dense_4h_to_h.lora_A', 'transformer.h.19.mlp.dense_4h_to_h.lora_B', 'transformer.h.9.mlp.dense_4h_to_h.lora_A', 'transformer.h.0.mlp.dense_4h_to_h.lora_B', 'transformer.h.6.mlp.dense_h_to_4h.lora_A', 'transformer.h.10.self_attention.query_key_value.lora_A', 'transformer.h.16.self_attention.dense.lora_A', 'transformer.h.4.mlp.dense_h_to_4h.lora_A', 'transformer.h.0.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.mlp.dense_4h_to_h.lora_A', 'transformer.h.1.mlp.dense_h_to_4h.lora_B', 'transformer.h.6.self_attention.dense.lora_B', 'transformer.h.15.self_attention.query_key_value.lora_A', 'transformer.h.0.self_attention.query_key_value.lora_A', 'transformer.h.0.mlp.dense_h_to_4h.lora_B', 'transformer.h.8.mlp.dense_4h_to_h.lora_B', 'transformer.h.10.mlp.dense_h_to_4h.lora_A', 'transformer.h.6.self_attention.dense.lora_A', 'transformer.h.11.self_attention.dense.lora_B', 'transformer.h.17.mlp.dense_4h_to_h.lora_A', 'transformer.h.23.mlp.dense_h_to_4h.lora_A', 'transformer.h.12.mlp.dense_h_to_4h.lora_B', 'transformer.h.8.mlp.dense_h_to_4h.lora_B', 'transformer.h.16.self_attention.dense.lora_B', 'transformer.h.15.self_attention.dense.lora_B', 'transformer.h.10.mlp.dense_h_to_4h.lora_B', 'transformer.h.20.mlp.dense_h_to_4h.lora_A', 'transformer.h.1.self_attention.dense.lora_A', 'transformer.h.2.mlp.dense_4h_to_h.lora_A', 'transformer.h.12.mlp.dense_4h_to_h.lora_A', 'transformer.h.3.mlp.dense_4h_to_h.lora_A', 'transformer.h.2.self_attention.query_key_value.lora_B', 'transformer.h.12.self_attention.dense.lora_B', 'transformer.h.18.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.mlp.dense_4h_to_h.lora_B', 'transformer.h.17.self_attention.dense.lora_B', 'transformer.h.16.self_attention.query_key_value.lora_A', 'transformer.h.23.self_attention.query_key_value.lora_B', 'transformer.h.7.mlp.dense_4h_to_h.lora_A', 'transformer.h.23.self_attention.dense.lora_B', 'transformer.h.14.mlp.dense_4h_to_h.lora_B', 'transformer.h.7.mlp.dense_h_to_4h.lora_B', 'transformer.h.8.self_attention.dense.lora_B', 'transformer.h.9.mlp.dense_h_to_4h.lora_A', 'transformer.h.14.self_attention.dense.lora_B', 'transformer.h.1.mlp.dense_4h_to_h.lora_B', 'transformer.h.1.self_attention.query_key_value.lora_B', 'transformer.h.8.self_attention.query_key_value.lora_B', 'transformer.h.8.mlp.dense_h_to_4h.lora_A', 'transformer.h.2.mlp.dense_4h_to_h.lora_B', 'transformer.h.21.self_attention.query_key_value.lora_B', 'transformer.h.20.mlp.dense_4h_to_h.lora_B', 'transformer.h.0.self_attention.query_key_value.lora_B', 'transformer.h.5.self_attention.dense.lora_A', 'transformer.h.2.mlp.dense_h_to_4h.lora_B', 'transformer.h.10.self_attention.dense.lora_A', 'transformer.h.4.self_attention.query_key_value.lora_A', 'transformer.h.14.self_attention.query_key_value.lora_B', 'transformer.h.8.self_attention.query_key_value.lora_A', 'transformer.h.18.self_attention.query_key_value.lora_A', 'transformer.h.4.mlp.dense_4h_to_h.lora_A', 'transformer.h.4.self_attention.query_key_value.lora_B', 'transformer.h.4.self_attention.dense.lora_B', 'transformer.h.21.self_attention.query_key_value.lora_A', 'transformer.h.7.self_attention.query_key_value.lora_B', 
'transformer.h.20.self_attention.dense.lora_A', 'transformer.h.18.mlp.dense_4h_to_h.lora_B', 'transformer.h.22.mlp.dense_h_to_4h.lora_B', 'transformer.h.22.mlp.dense_h_to_4h.lora_A', 'transformer.h.9.self_attention.query_key_value.lora_A', 'transformer.h.11.mlp.dense_h_to_4h.lora_B', 'transformer.h.5.self_attention.dense.lora_B', 'transformer.h.3.self_attention.query_key_value.lora_B', 'transformer.h.11.self_attention.dense.lora_A', 'transformer.h.23.self_attention.query_key_value.lora_A', 'transformer.h.23.mlp.dense_h_to_4h.lora_B', 'transformer.h.20.self_attention.dense.lora_B', 'transformer.h.3.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.mlp.dense_4h_to_h.lora_A', 'transformer.h.13.mlp.dense_4h_to_h.lora_A', 'transformer.h.6.mlp.dense_4h_to_h.lora_A', 'transformer.h.6.mlp.dense_4h_to_h.lora_B', 'transformer.h.1.self_attention.query_key_value.lora_A', 'transformer.h.5.mlp.dense_h_to_4h.lora_A', 'transformer.h.13.self_attention.query_key_value.lora_B', 'lm_head.lora_B', 'transformer.h.0.mlp.dense_h_to_4h.lora_A', 'transformer.h.7.mlp.dense_4h_to_h.lora_B', 'transformer.h.22.self_attention.dense.lora_B', 'transformer.h.19.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.self_attention.dense.lora_A', 'transformer.h.16.mlp.dense_h_to_4h.lora_B', 'transformer.h.6.mlp.dense_h_to_4h.lora_B', 'transformer.h.22.self_attention.query_key_value.lora_A', 'lm_head.weight', 'transformer.h.22.self_attention.dense.lora_A', 'transformer.h.15.mlp.dense_4h_to_h.lora_A', 'transformer.h.15.self_attention.dense.lora_A', 'transformer.h.10.self_attention.dense.lora_B', 'transformer.h.12.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.self_attention.query_key_value.lora_B', 'transformer.h.23.self_attention.dense.lora_A', 'transformer.h.22.mlp.dense_4h_to_h.lora_B', 'transformer.h.11.mlp.dense_h_to_4h.lora_A', 'transformer.h.13.self_attention.dense.lora_B', 'transformer.h.20.self_attention.query_key_value.lora_A', 'transformer.h.5.mlp.dense_4h_to_h.lora_B', 'transformer.h.4.mlp.dense_h_to_4h.lora_B', 'transformer.h.19.self_attention.query_key_value.lora_B', 'transformer.h.20.mlp.dense_h_to_4h.lora_B', 'transformer.h.4.self_attention.dense.lora_A', 'transformer.h.19.self_attention.dense.lora_B', 'transformer.h.9.mlp.dense_4h_to_h.lora_B', 'transformer.h.17.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.self_attention.query_key_value.lora_A', 'transformer.h.21.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.self_attention.dense.lora_A', 'transformer.h.17.mlp.dense_4h_to_h.lora_B']
- This IS expected if you are initializing BloomModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BloomModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[05/21/23 00:32:36] INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/prompt_dataset.py:30
                             __init__
                    INFO     colossalai - colossalai - INFO: Loading data...
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/prompt_dataset.py:32
                             __init__
                    INFO     colossalai - colossalai - INFO: Loaded 858 examples.
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/prompt_dataset.py:35
                             __init__
                    INFO     colossalai - colossalai - INFO: Limiting dataset to 16384 examples.
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:121
                             __init__
                    INFO     colossalai - colossalai - INFO: Loading data...
[05/21/23 00:32:37] INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:123
                             __init__
                    INFO     colossalai - colossalai - INFO: Loaded 103695 examples.
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:126
                             __init__
                    INFO     colossalai - colossalai - INFO: Limiting dataset to 16384 examples.
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:129
                             __init__
                    INFO     colossalai - colossalai - INFO: Formatting inputs...
                    INFO     colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:137
                             __init__
                    INFO     colossalai - colossalai - INFO: Tokenizing inputs... This may take some time...
Train epoch [1/1]: 100%|████████████████████████████████████████████████████████████████████████████| 10/10 [00:04<00:00,  2.45it/s, reward=nan]
Episode [1/1]: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:20<00:00,  2.03s/it]

3.4 Inference

3.4.1 Code

import torch
from coati.models.bloom import BLOOMActor
from coati.models.gpt import GPTActor
from coati.models.opt import OPTActor
from coati.models.roberta import RoBERTaActor
from transformers import AutoTokenizer, RobertaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

import gradio as gr

MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2

# Replace these with the paths to your own model checkpoints
model_path_dict = {
    'SFT': '/mnt/f/kangpengtao/study/ColossalAI/Coati-7B/pytorch_model.bin',
    'RM': '/mnt/f/kangpengtao/study/ColossalAI/rm-static.pt',
    'RL': '/mnt/f/kangpengtao/study/ColossalAI/prompts-static.pt',
}

pretrain_dict = {
    'bloom': '/mnt/f/kangpengtao/study/ColossalAI/bigscience/bloom-560m'
}

def predict(model, dict, input, max_length, history):
    pretrain = pretrain_dict[model]
    updates = []
    # configure model
    if model == 'gpt2':
        actor = GPTActor(pretrained=pretrain).to(torch.cuda.current_device())
    elif model == 'bloom':
        actor = BLOOMActor(pretrained=pretrain).to(torch.cuda.current_device())
    elif model == 'opt':
        actor = OPTActor(pretrained=pretrain).to(torch.cuda.current_device())
    elif model == 'roberta':
        actor = RoBERTaActor(pretrained=pretrain).to(torch.cuda.current_device())
    else:
        raise ValueError(f'Unsupported model "{model}"')

    state_dict = torch.load(model_path_dict[dict])
    actor.model.load_state_dict(state_dict, strict=False)

    # configure tokenizer
    if model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif model == 'bloom':
        tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
        tokenizer.pad_token = tokenizer.eos_token
    elif model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
    elif model == 'roberta':
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    else:
        raise ValueError(f'Unsupported model "{model}"')

    actor.eval()
    question = f'Question: {input} ? Answer:'
    input_ids = tokenizer.encode(question, return_tensors='pt').to(torch.cuda.current_device())
    outputs = actor.generate(input_ids,
                             max_length=max_length,
                             do_sample=True,
                             top_k=50,
                             top_p=0.95,
                             num_return_sequences=1)
    output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
    for i in history:
        if not i.get('visible'):
            continue
        print(i)
        value = i.get('value')
        updates.append(gr.update(visible=True, value=value))

    updates.append(gr.update(visible=True, value="提问:" + input))
    updates.append(gr.update(visible=True, value=f"{dict}:" + output[0].replace(question, '').replace(question.replace(' ', ''), '')))
    if len(updates) < MAX_BOXES:
        updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
    history.extend(updates)
    return [history] + updates


if __name__ == '__main__':
    with gr.Blocks() as demo:
        state = gr.State([])
        text_boxes = []

        with gr.Row():
            with gr.Column(scale=1):
                model = gr.Radio(['gpt2', 'bloom', 'opt', 'roberta'], label="model", interactive=False, value='bloom')
                dict = gr.Radio(["SFT", "RM", "RL"], label="dict", interactive=True, value='RL')
                max_length = gr.Slider(0, 512, value=100, step=1.0, label="max_length", interactive=True)
                button = gr.Button("Generate")

            with gr.Column(scale=4):
                for i in range(MAX_BOXES):
                    if i % 2 == 0:
                        text_boxes += [gr.Markdown(visible=False, label="提问:")]
                    else:
                        text_boxes += [gr.Markdown(visible=False, label="回复:")]
                input = gr.Textbox(show_label=True, placeholder="input", lines=5, label='input').style(container=False)

        button.click(predict, [model, dict, input, max_length, state], [state] + text_boxes)
    demo.queue().launch(share=False, inbrowser=True, server_name='0.0.0.0')

3.4.2 Demo

提问:了解一个无家可归或社会边缘人的故事,并考虑自己的角色和关怀方式。

RL:Question: 了解一个无家可归或社会边缘人的故事,并考虑自己的角色和关怀方式。? Answer: 无论你是单身、独居或已婚,不管你和谁在一起,或者无论在哪里工作,都必须向世界发出一个邀请:去帮助他人!? Excluded Subjects: 婚姻和性关系、情感、女性、单身和女性伴侣、同性恋和异性伴侣。 作者简介 作者:(英国)戴维·巴兰森(David Baker) 译

提问:在机器学习中,什么是监督学习和无监督学习?你如何选择适当的算法?

RL:Question: 在机器学习中,什么是监督学习和无监督学习?你如何选择适当的算法?? Answer: 机器学习的深度学习方法无监督学习(Deep Learning of No-Layered):它比传统深度学习(深度学习)更具灵活性,例如,您可以通过训练网络学习某种类型(非学习任务)的模式(比如,网络训练了某条路径),然后在网络中生成新的特征,以便获得更高的模型效果(对于非学习任务,模型可能非常糟糕

提问:给定一组JSON对象,写一个程序来计算其中所有数字值的平均值。

RL:Question: 给定一组JSON对象,写一个程序来计算其中所有数字值的平均值。? Answer: 返回每一个数字的平均值。? Answer: 如果你有一个数组中的每一列都是一个数字,那么可以使用这种方法计算每一个数字的平均值。? Answer: 默认情况下,JSON会返回一个返回一个数组。? Question: 你需要一个JSON对象来存储一组数据(一个字符串和一个集合)。例如,如果你给一个字符串定义了8个属性,则应该使用这个对象来存储这些属性(这些属性被抽象到一个类下面)。? Answer: 如果该字符串的每个值都一个数字,那么JSON会返回一个数组,因为数字比较准确。但是如果你给一个字符串定义了12个属性,那么JSON会返回一个返回一个数组。这个问题就在这里。? Question: 你需要一个JSON对象来存储一组数据(一个字符串和一个集合)。例如,如果你给一个字符串定义了8个属性,那么应该使用这个对象来存储这些属性(这些属性被抽象到一个类下面)。? Answer: 如果该字符串的每个值都一个数字,那么JSON会返回一个数组,因为数字比较准确。但是如果你给一个字符串定义了12个属性,那么JSON会返回一个返回一个数组。这个问题就在这里。? Question: 你需要一个JSON对象来存储一组数据(一个字符串和一个集合)。例如,如果你给一个字符串定义了8个属性,那么应该使用这个对象来存储这些属性(这些属性被抽象到一个类下面)。? Answer: 如果该字符串的每个值都一个数字,那么JSON会返回一个数组,因为数字比较准确。但是如果你给一个字符串定义了12个属性,那么JSON会返回一个返回一个数组。这个问题就在这里。? Question: 你需要一个JSON对象来存储一组数据(一个字符串和一个集合)。例如,如果你给一个字符串定义了8个属性,那么应该使用这个对象来存储这些属性(这些属性被抽象到一个类下面)。? Answer: 如果该字符串的每个值都一个数字,那么JSON会返回一个数组,因为数字比较准确。但是如果你给一个字符串定义了12个属性,那么JSON会返回一个返回一个数组。这个问题就在这里。? Question: 你需要一个JSON对象来存储一组数据(一个字符串和一个集合)。例如,如果你给一个字符串定义了8个属性,那么应该使用这个对象来存储这些属性(这些属性被抽象到一个类下面)。? Answer

4 References

https://blog.csdn.net/chen_hao_181/article/details/130172096

Note: Markdown-to-WeChat-official-account conversion tool: https://md.guozh.net/
