Colossal-AI Reproduction Walkthrough
1 Environment Setup
1.1 CUDA environment
root@LAPTOP-3SUHS40U:/home/work/ColossalAI# lsb_release -a
LSB Version: core-11.1.0ubuntu2-noarch:security-11.1.0ubuntu2-noarch
Distributor ID: Ubuntu
Description: Ubuntu 20.04.6 LTS
Release: 20.04
Codename: focal
root@LAPTOP-3SUHS40U:/home/work/ColossalAI# nvidia-smi
Sun May 21 05:56:52 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.50 Driver Version: 531.79 CUDA Version: 12.1 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 2060 On | 00000000:01:00.0 On | N/A |
| N/A 47C P8 7W / N/A| 5760MiB / 6144MiB | 10% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 23 G /Xwayland N/A |
| 0 N/A N/A 306013 C /python3.10 N/A |
+---------------------------------------------------------------------------------------+
1.2 Python environment
root@LAPTOP-3SUHS40U:/home/work/ColossalAI# python --version
Python 3.10.10
1.3 Python package environment
root@LAPTOP-3SUHS40U:/home/work/ColossalAI# pip list
Package Version Editable project location
----------------------------- ---------------- ------------------------------------------------
absl-py 1.4.0
aiofiles 23.1.0
aiohttp 3.8.4
aiosignal 1.3.1
alabaster 0.7.13
altair 5.0.0
anyio 3.6.2
appdirs 1.4.4
arrow 1.2.3
astunparse 1.6.3
async-timeout 4.0.2
attrs 23.1.0
Babel 2.12.1
bcrypt 4.0.1
benepar 0.2.0
blessed 1.20.0
blis 0.7.9
boltons 23.0.0
brotlipy 0.7.0
cachetools 5.3.0
catalogue 2.0.8
certifi 2022.12.7
cffi 1.15.1
cfgv 3.3.1
charset-normalizer 2.0.4
click 8.1.3
cmake 3.26.3
coati 1.0.0
colossalai 0.2.8
conda 23.3.1
conda-content-trust 0.1.3
conda-package-handling 2.0.2
conda_package_streaming 0.7.0
confection 0.0.4
contexttimer 0.3.3
coverage 7.2.5
cryptography 39.0.1
cycler 0.11.0
cymem 2.0.7
Cython 0.29.34
dataclasses-json 0.5.7
datasets 2.12.0
diffusers 0.16.1
dill 0.3.6
distlib 0.3.6
distro 1.8.0
dnspython 2.3.0
docker 6.0.1
docker-pycreds 0.4.0
docstring-parser 0.8.1
docutils 0.18.1
einops 0.6.1
et-xmlfile 1.1.0
exceptiongroup 1.1.1
expecttest 0.1.4
fabric 3.0.1
fastapi 0.95.1
fbgemm-gpu 0.4.1
ffmpy 0.3.0
filelock 3.12.0
flash-attn 0.1
fonttools 4.39.3
frozenlist 1.3.3
fsspec 2023.4.0
gitdb 4.0.10
GitPython 3.1.31
google-auth 2.17.3
google-auth-oauthlib 0.4.6
gpustat 1.1
gradio 3.30.0
gradio_client 0.2.4
greenlet 2.0.2
grpcio 1.54.0
h11 0.14.0
httpcore 0.17.0
httpx 0.24.0
huggingface-hub 0.14.1
hypothesis 6.75.1
identify 2.5.23
idna 3.4
imagesize 1.4.1
importlib-metadata 6.6.0
iniconfig 2.0.0
invoke 2.1.1
iopath 0.1.10
Jinja2 3.1.2
joblib 1.2.0
jsonpatch 1.32
jsonpointer 2.1
jsonschema 4.17.3
kiwisolver 1.4.4
langchain 0.0.161
langcodes 3.3.0
libcst 0.4.9
linkify-it-py 2.0.2
loralib 0.1.1
Markdown 3.4.3
markdown-it-py 2.2.0
MarkupSafe 2.1.2
marshmallow 3.19.0
marshmallow-enum 1.5.1
matplotlib 3.5.3
mdit-py-plugins 0.3.3
mdurl 0.1.2
moreorless 0.4.0
mpmath 1.3.0
multidict 6.0.4
multiprocess 0.70.14
murmurhash 1.0.9
mypy-extensions 1.0.0
myst-parser 0.18.1
networkx 3.1
ninja 1.11.1
nltk 3.8.1
nodeenv 1.7.0
numexpr 2.8.4
numpy 1.24.3
nvidia-cublas-cu11 11.10.3.66
nvidia-cuda-nvrtc-cu11 11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11 8.5.0.96
nvidia-ml-py 11.525.112
oauthlib 3.2.2
openai 0.24.0
openapi-schema-pydantic 1.2.4
openpyxl 3.1.2
orjson 3.8.12
packaging 23.0
pandas 2.0.1
pandas-stubs 2.0.1.230501
paramiko 3.1.0
pathspec 0.11.1
pathtools 0.1.2
pathy 0.10.1
Pillow 9.5.0
pip 23.0.1
platformdirs 3.5.0
plotly 5.14.1
pluggy 1.0.0
portalocker 2.7.0
pre-commit 3.3.1
preshed 3.0.8
protobuf 3.19.6
psutil 5.9.5
pyarrow 12.0.0
pyasn1 0.5.0
pyasn1-modules 0.3.0
pycosat 0.6.4
pycparser 2.21
pydantic 1.10.7
pyDeprecate 0.3.2
pydub 0.25.1
Pygments 2.15.1
PyNaCl 1.5.0
pyOpenSSL 23.0.0
pyparsing 3.0.9
pyre-extensions 0.0.27
pyrsistent 0.19.3
PySocks 1.7.1
pytest 7.3.1
pytest-cov 4.0.0
python-dateutil 2.8.2
python-etcd 0.4.5
python-multipart 0.0.6
pytorch-sphinx-theme 0.0.24 /home/work/pytorch/docs/src/pytorch-sphinx-theme
pytorch-triton 2.1.0+7d1a95b046
pytz 2023.3
PyYAML 6.0
regex 2023.5.4
requests 2.27.1
requests-oauthlib 1.3.1
responses 0.18.0
rich 13.3.5
rouge-score 0.1.2
rsa 4.9
ruamel.yaml 0.17.21
ruamel.yaml.clib 0.2.6
safetensors 0.3.1
scikit-build 0.17.3
semantic-version 2.10.0
sentencepiece 0.1.99
sentry-sdk 1.22.1
setproctitle 1.3.2
setuptools 65.6.3
six 1.16.0
smart-open 6.3.0
smmap 5.0.0
sniffio 1.3.0
snowballstemmer 2.2.0
sortedcontainers 2.4.0
spacy 3.5.2
spacy-legacy 3.0.12
spacy-loggers 1.0.4
Sphinx 5.0.0
sphinx-copybutton 0.5.0
sphinx-panels 0.4.1
sphinxcontrib-applehelp 1.0.4
sphinxcontrib-devhelp 1.0.2
sphinxcontrib-htmlhelp 2.0.1
sphinxcontrib-jsmath 1.0.1
sphinxcontrib-katex 0.8.6
sphinxcontrib-qthelp 1.0.3
sphinxcontrib-serializinghtml 1.1.5
SQLAlchemy 2.0.12
srsly 2.4.6
sse-starlette 1.5.0
starlette 0.26.1
stdlibs 2022.10.9
sympy 1.11.1
tabulate 0.9.0
tenacity 8.2.2
tensorboard 2.10.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
thinc 8.1.10
timm 0.6.13
titans 0.0.7
tokenizers 0.13.3
toml 0.10.2
tomli 2.0.1
toolz 0.12.0
torch 1.13.1 /root/miniconda3/lib/python3.10/site-packages
torch-struct 0.5
torchaudio 0.13.1
torchmetrics 0.11.4
torchrec 0.2.0
torchvision 0.14.1
torchx-nightly 2023.5.3
tqdm 4.65.0
trailrunner 1.4.0
transformers 4.28.0.dev0
typer 0.7.0
types-dataclasses 0.6.6
types-pytz 2023.3.0.0
typing_extensions 4.5.0
typing-inspect 0.8.0
tzdata 2023.3
uc-micro-py 1.0.2
urllib3 1.26.15
usort 1.0.6
uvicorn 0.22.0
virtualenv 20.23.0
wandb 0.15.2
wasabi 1.1.1
wcwidth 0.2.6
websocket-client 1.5.1
websockets 11.0.3
Werkzeug 2.3.3
wheel 0.38.4
xxhash 3.2.0
yarl 1.9.2
zipp 3.15.0
zstandard 0.19.0
2 Download the Code
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
CUDA_EXT=1 pip install .
3 Model Training
3.1 SFT (supervised fine-tuning)
3.1.1 Command
torchrun --standalone --nproc_per_node=1 train_sft.py \
--pretrain "/mnt/f/kangpengtao/study/ColossalAI/bigscience/bloom-560m/" \
--model 'bloom' \
--strategy naive \
--log_interval 10 \
--save_path /mnt/f/kangpengtao/study/ColossalAI/Coati-7B \
--dataset /mnt/f/kangpengtao/study/ColossalAI/InstructionWild/data/instinwild.json \
--batch_size 1 \
--accumulation_steps 8 \
--lr 2e-5 \
--max_datasets_size 16384 \
--max_epochs 1 \
--lora_rank 16
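With --batch_size 1 and --accumulation_steps 8 over --max_datasets_size 16384 examples, one epoch performs 16384 / (1 × 8) = 2048 optimizer updates, which is exactly the 0/2048 progress bar shown in the log below. A small sketch of that arithmetic (the `optimizer_steps` helper is hypothetical, not part of coati):

```python
import math

def optimizer_steps(dataset_size: int, batch_size: int,
                    accumulation_steps: int, epochs: int = 1) -> int:
    """One optimizer update per `accumulation_steps` micro-batches."""
    micro_batches = math.ceil(dataset_size / batch_size)
    return math.ceil(micro_batches / accumulation_steps) * epochs

print(optimizer_steps(16384, 1, 8))  # → 2048
```

Gradient accumulation keeps the effective batch size at 8 while only one sample at a time has to fit in the RTX 2060's 6 GB of memory.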
3.1.2 Log
root@LAPTOP-3SUHS40U:/home/work/ColossalAI/applications/Chat/examples# ./train_sft_bloom_kpt.sh
[05/16/23 15:54:24] INFO colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:121 __init__
INFO colossalai - colossalai - INFO: Loading data...
[05/16/23 15:54:25] INFO colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:123 __init__
INFO colossalai - colossalai - INFO: Loaded 103695 examples.
INFO colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:126 __init__
INFO colossalai - colossalai - INFO: Limiting dataset to 16384 examples.
INFO colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:129 __init__
INFO colossalai - colossalai - INFO: Formatting inputs...
INFO colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:137 __init__
INFO colossalai - colossalai - INFO: Tokenizing inputs... This may take some time...
steps: 0%| | 0/2048 [00:00<?, ?it/s][05/16/23 15:54:40] WARNING colossalai - colossalai - WARNING: /root/miniconda3/lib/python3.10/site-packages/coati/trainer/sft.py:86 fit
WARNING colossalai - colossalai - WARNING: batch_id:7, abnormal loss: 2.74609375
steps: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2048/2048 [36:00<00:00, 1.05s/it]
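The --lora_rank 16 flag trains low-rank LoRA adapters instead of the full weights: each frozen weight matrix W gets a trainable rank-16 correction B @ A, stored as the lora_A/lora_B tensors that later show up in the checkpoint warnings of the reward-model log. A minimal pure-Python sketch of the idea (illustrative only, not the coati implementation):

```python
def lora_forward(W, A, B, x):
    """y = (W + B @ A) @ x, computed without materializing the merged weight.
    W: d_out x d_in (frozen); A: r x d_in (lora_A); B: d_out x r (lora_B)."""
    Wx  = [sum(w * xi for w, xi in zip(row, x)) for row in W]   # frozen path
    Ax  = [sum(a * xi for a, xi in zip(row, x)) for row in A]   # project to rank r
    BAx = [sum(b * ai for b, ai in zip(row, Ax)) for row in B]  # project back
    return [wx + bax for wx, bax in zip(Wx, BAx)]

# Tiny example: identity W with a rank-1 adapter
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 1.0]]        # lora_A: 1 x 2
B = [[0.5], [0.5]]      # lora_B: 2 x 1
print(lora_forward(W, A, B, [2.0, 4.0]))  # → [5.0, 7.0]
```

For a 1024×1024 layer, rank 16 means training 2 × 16 × 1024 parameters instead of 1024², which is what makes fine-tuning feasible on a 6 GB laptop GPU.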
3.2 Training the Reward Model
3.2.1 Command
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 1
torchrun --standalone --nproc_per_node=1 train_reward_model.py \
--pretrain '/mnt/f/kangpengtao/study/ColossalAI/Coati-7B' \
--model 'bloom' \
--strategy naive \
--loss_fn 'log_sig' \
--dataset 'Anthropic/hh-rlhf' \
--save_path '/mnt/f/kangpengtao/study/ColossalAI/rm-static.pt' \
--lora_rank 16 \
--batch_size 1 \
--max_len 512
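--loss_fn 'log_sig' selects a pairwise ranking loss. Assuming this is the standard Bradley-Terry objective -log(sigmoid(r_chosen - r_rejected)) (the exact coati implementation may differ), a minimal sketch:

```python
import math

def log_sig_loss(reward_chosen: float, reward_rejected: float) -> float:
    """Pairwise ranking loss: small when the chosen answer already scores
    higher than the rejected one, large when the ranking is inverted."""
    margin = reward_chosen - reward_rejected
    return -math.log(1.0 / (1.0 + math.exp(-margin)))

print(round(log_sig_loss(2.0, 0.0), 4))  # → 0.1269 (correct ranking, small loss)
print(round(log_sig_loss(0.0, 2.0), 4))  # → 2.1269 (inverted ranking, large loss)
```

Each hh-rlhf record supplies one chosen and one rejected response; the reward model is trained only on this relative preference, never on absolute scores.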
3.2.2 Log
root@LAPTOP-3SUHS40U:/home/work/ColossalAI/applications/Chat/examples# ./train_rm_bloom_kpt.sh
GPU Memory Usage:
0 228 MiB
Now CUDA_VISIBLE_DEVICES is set to:
CUDA_VISIBLE_DEVICES=0
Some weights of the model checkpoint at /mnt/f/kangpengtao/study/ColossalAI/Coati-7B were not used when initializing BloomModel: ['transformer.h.23.mlp.dense_4h_to_h.lora_B', 'transformer.h.11.mlp.dense_4h_to_h.lora_B', 'transformer.h.18.self_attention.dense.lora_B', 'transformer.h.6.mlp.dense_4h_to_h.lora_B', 'transformer.h.11.mlp.dense_4h_to_h.lora_A', 'transformer.h.11.self_attention.query_key_value.lora_B', 'transformer.h.21.self_attention.query_key_value.lora_B', 'transformer.h.22.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.mlp.dense_h_to_4h.lora_A', 'lm_head.lora_A', 'transformer.h.22.self_attention.dense.lora_A', 'transformer.h.6.self_attention.query_key_value.lora_A', 'transformer.h.14.mlp.dense_h_to_4h.lora_A', 'transformer.h.5.mlp.dense_h_to_4h.lora_B', 'transformer.h.9.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.self_attention.query_key_value.lora_A', 'transformer.h.6.mlp.dense_h_to_4h.lora_B', 'transformer.h.20.mlp.dense_4h_to_h.lora_A', 'transformer.h.19.self_attention.dense.lora_A', 'transformer.h.17.self_attention.query_key_value.lora_B', 'transformer.h.0.mlp.dense_4h_to_h.lora_B', 'transformer.h.13.self_attention.query_key_value.lora_A', 'transformer.h.23.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.self_attention.query_key_value.lora_B', 'transformer.h.9.self_attention.dense.lora_B', 'transformer.h.23.mlp.dense_4h_to_h.lora_A', 'transformer.h.9.self_attention.query_key_value.lora_B', 'transformer.h.0.mlp.dense_h_to_4h.lora_A', 'transformer.h.19.mlp.dense_4h_to_h.lora_A', 'transformer.h.4.self_attention.dense.lora_B', 'transformer.h.16.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.self_attention.dense.lora_B', 'transformer.h.4.mlp.dense_4h_to_h.lora_B', 'transformer.h.16.self_attention.query_key_value.lora_B', 'transformer.h.5.mlp.dense_4h_to_h.lora_B', 'transformer.h.10.mlp.dense_4h_to_h.lora_A', 'transformer.h.18.self_attention.query_key_value.lora_A', 'transformer.h.12.self_attention.query_key_value.lora_A', 
'transformer.h.22.self_attention.dense.lora_B', 'transformer.h.1.mlp.dense_4h_to_h.lora_A', 'transformer.h.6.self_attention.dense.lora_A', 'transformer.h.13.mlp.dense_h_to_4h.lora_B', 'transformer.h.2.self_attention.dense.lora_A', 'transformer.h.1.self_attention.query_key_value.lora_A', 'transformer.h.12.mlp.dense_h_to_4h.lora_A', 'transformer.h.4.self_attention.query_key_value.lora_B', 'transformer.h.13.self_attention.query_key_value.lora_B', 'lm_head.lora_B', 'transformer.h.17.mlp.dense_4h_to_h.lora_B', 'transformer.h.12.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.self_attention.query_key_value.lora_A', 'transformer.h.21.mlp.dense_4h_to_h.lora_B', 'transformer.h.20.mlp.dense_h_to_4h.lora_B', 'transformer.h.15.self_attention.dense.lora_B', 'transformer.h.11.self_attention.dense.lora_B', 'transformer.h.2.self_attention.query_key_value.lora_B', 'transformer.h.16.mlp.dense_h_to_4h.lora_A', 'transformer.h.5.self_attention.query_key_value.lora_A', 'transformer.h.16.self_attention.query_key_value.lora_A', 'transformer.h.23.self_attention.dense.lora_B', 'transformer.h.13.self_attention.dense.lora_B', 'transformer.h.6.self_attention.dense.lora_B', 'transformer.h.13.self_attention.dense.lora_A', 'transformer.h.1.self_attention.query_key_value.lora_B', 'transformer.h.22.self_attention.query_key_value.lora_A', 'transformer.h.15.mlp.dense_4h_to_h.lora_B', 'transformer.h.12.self_attention.dense.lora_B', 'transformer.h.2.self_attention.dense.lora_B', 'transformer.h.15.mlp.dense_4h_to_h.lora_A', 'transformer.h.1.mlp.dense_h_to_4h.lora_B', 'transformer.h.11.self_attention.dense.lora_A', 'transformer.h.19.self_attention.query_key_value.lora_A', 'transformer.h.4.self_attention.dense.lora_A', 'transformer.h.14.self_attention.query_key_value.lora_B', 'transformer.h.9.self_attention.dense.lora_A', 'transformer.h.22.mlp.dense_h_to_4h.lora_B', 'transformer.h.15.self_attention.dense.lora_A', 'transformer.h.0.mlp.dense_4h_to_h.lora_A', 
'transformer.h.3.self_attention.query_key_value.lora_B', 'transformer.h.17.mlp.dense_4h_to_h.lora_A', 'transformer.h.22.self_attention.query_key_value.lora_B', 'transformer.h.7.self_attention.dense.lora_B', 'transformer.h.5.mlp.dense_4h_to_h.lora_A', 'transformer.h.10.self_attention.query_key_value.lora_A', 'transformer.h.22.mlp.dense_4h_to_h.lora_B', 'transformer.h.7.self_attention.query_key_value.lora_A', 'transformer.h.2.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.self_attention.dense.lora_A', 'transformer.h.15.mlp.dense_h_to_4h.lora_B', 'transformer.h.11.mlp.dense_h_to_4h.lora_A', 'transformer.h.0.self_attention.dense.lora_A', 'transformer.h.3.mlp.dense_h_to_4h.lora_A', 'transformer.h.19.mlp.dense_h_to_4h.lora_B', 'transformer.h.5.mlp.dense_h_to_4h.lora_A', 'transformer.h.3.self_attention.dense.lora_A', 'transformer.h.10.mlp.dense_4h_to_h.lora_B', 'transformer.h.9.self_attention.query_key_value.lora_A', 'transformer.h.8.self_attention.dense.lora_B', 'transformer.h.12.mlp.dense_4h_to_h.lora_A', 'transformer.h.19.mlp.dense_4h_to_h.lora_B', 'transformer.h.6.self_attention.query_key_value.lora_B', 'transformer.h.9.mlp.dense_4h_to_h.lora_B', 'transformer.h.13.mlp.dense_4h_to_h.lora_B', 'transformer.h.12.self_attention.query_key_value.lora_B', 'transformer.h.16.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.mlp.dense_h_to_4h.lora_A', 'transformer.h.5.self_attention.dense.lora_B', 'transformer.h.17.self_attention.query_key_value.lora_A', 'transformer.h.9.mlp.dense_4h_to_h.lora_A', 'transformer.h.10.self_attention.dense.lora_A', 'transformer.h.1.mlp.dense_h_to_4h.lora_A', 'transformer.h.21.self_attention.query_key_value.lora_A', 'transformer.h.10.mlp.dense_h_to_4h.lora_A', 'transformer.h.15.mlp.dense_h_to_4h.lora_A', 'transformer.h.8.mlp.dense_h_to_4h.lora_B', 'transformer.h.21.mlp.dense_h_to_4h.lora_A', 'transformer.h.7.self_attention.dense.lora_A', 'transformer.h.16.self_attention.dense.lora_B', 'transformer.h.17.self_attention.dense.lora_A', 
'transformer.h.20.mlp.dense_4h_to_h.lora_B', 'transformer.h.15.self_attention.query_key_value.lora_B', 'transformer.h.22.mlp.dense_4h_to_h.lora_A', 'transformer.h.18.self_attention.query_key_value.lora_B', 'transformer.h.13.mlp.dense_h_to_4h.lora_A', 'transformer.h.4.mlp.dense_4h_to_h.lora_A', 'transformer.h.1.mlp.dense_4h_to_h.lora_B', 'transformer.h.15.self_attention.query_key_value.lora_A', 'transformer.h.11.self_attention.query_key_value.lora_A', 'transformer.h.3.self_attention.dense.lora_B', 'transformer.h.4.self_attention.query_key_value.lora_A', 'transformer.h.0.self_attention.dense.lora_B', 'transformer.h.13.mlp.dense_4h_to_h.lora_A', 'transformer.h.3.mlp.dense_4h_to_h.lora_B', 'transformer.h.5.self_attention.dense.lora_A', 'transformer.h.10.self_attention.dense.lora_B', 'transformer.h.23.mlp.dense_h_to_4h.lora_A', 'transformer.h.3.self_attention.query_key_value.lora_A', 'transformer.h.21.self_attention.dense.lora_B', 'transformer.h.17.mlp.dense_h_to_4h.lora_B', 'transformer.h.2.mlp.dense_4h_to_h.lora_B', 'transformer.h.8.self_attention.query_key_value.lora_A', 'transformer.h.11.mlp.dense_h_to_4h.lora_B', 'transformer.h.3.mlp.dense_4h_to_h.lora_A', 'transformer.h.18.mlp.dense_4h_to_h.lora_B', 'transformer.h.16.mlp.dense_4h_to_h.lora_B', 'transformer.h.21.mlp.dense_h_to_4h.lora_B', 'transformer.h.2.mlp.dense_4h_to_h.lora_A', 'transformer.h.7.mlp.dense_h_to_4h.lora_B', 'transformer.h.18.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.mlp.dense_4h_to_h.lora_A', 'transformer.h.4.mlp.dense_h_to_4h.lora_B', 'transformer.h.7.self_attention.query_key_value.lora_B', 'transformer.h.9.mlp.dense_h_to_4h.lora_B', 'transformer.h.21.self_attention.dense.lora_A', 'transformer.h.20.self_attention.dense.lora_B', 'transformer.h.8.self_attention.query_key_value.lora_B', 'transformer.h.8.self_attention.dense.lora_A', 'transformer.h.2.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.mlp.dense_4h_to_h.lora_B', 
'transformer.h.6.mlp.dense_4h_to_h.lora_A', 'transformer.h.7.mlp.dense_4h_to_h.lora_B', 'transformer.h.18.mlp.dense_h_to_4h.lora_B', 'transformer.h.17.self_attention.dense.lora_B', 'transformer.h.0.self_attention.query_key_value.lora_A', 'transformer.h.7.mlp.dense_4h_to_h.lora_A', 'transformer.h.12.mlp.dense_4h_to_h.lora_B', 'transformer.h.10.self_attention.query_key_value.lora_B', 'transformer.h.19.self_attention.query_key_value.lora_B', 'transformer.h.21.mlp.dense_4h_to_h.lora_A', 'transformer.h.14.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.self_attention.dense.lora_A', 'transformer.h.4.mlp.dense_h_to_4h.lora_A', 'transformer.h.18.self_attention.dense.lora_A', 'transformer.h.7.mlp.dense_h_to_4h.lora_A', 'transformer.h.19.self_attention.dense.lora_B', 'transformer.h.19.mlp.dense_h_to_4h.lora_A', 'transformer.h.17.mlp.dense_h_to_4h.lora_A', 'transformer.h.18.mlp.dense_h_to_4h.lora_A', 'transformer.h.0.self_attention.query_key_value.lora_B', 'transformer.h.14.mlp.dense_4h_to_h.lora_B', 'transformer.h.16.self_attention.dense.lora_A', 'transformer.h.10.mlp.dense_h_to_4h.lora_B', 'lm_head.weight', 'transformer.h.1.self_attention.dense.lora_B', 'transformer.h.3.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.self_attention.dense.lora_A', 'transformer.h.5.self_attention.query_key_value.lora_B', 'transformer.h.0.mlp.dense_h_to_4h.lora_B', 'transformer.h.2.self_attention.query_key_value.lora_A', 'transformer.h.14.self_attention.query_key_value.lora_A', 'transformer.h.1.self_attention.dense.lora_A', 'transformer.h.12.self_attention.dense.lora_A', 'transformer.h.6.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.self_attention.query_key_value.lora_B']
- This IS expected if you are initializing BloomModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BloomModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Found cached dataset json (/root/.cache/huggingface/datasets/json/hh-rlhf-226e0526113c616f/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 383.92it/s]
Parameter 'indices'=<generator object train.<locals>.<genexpr> at 0x7fdf71767ca0> of the transform datasets.arrow_dataset.Dataset.select couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 160800/160800 [05:40<00:00, 472.16it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1710/1710 [00:04<00:00, 422.55it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8552/8552 [00:19<00:00, 433.97it/s]
Train step of epoch 0: 100%|███████████████████████████████████████████████████████████████████| 160800/160800 [100:45:04<00:00, 2.26s/it, dist=nan, acc=0]
Train epoch: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [100:45:04<00:00, 362704.33s/it]
root@LAPTOP-3SUHS40U:/home/work/ColossalAI/applications/Chat/examples#
3.3 RL (training the model on prompts with RL)
3.3.1 Command
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 1
# torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
# torchrun --standalone --nproc_per_node=2 train_prompts.py --prompt_dataset /path/to/data.json --strategy colossalai_zero2
torchrun --standalone --nproc_per_node=1 train_prompts.py \
--prompt_dataset '/mnt/f/kangpengtao/study/ColossalAI/InstructionWild/data/prompts.json' \
--pretrain_dataset '/mnt/f/kangpengtao/study/ColossalAI/InstructionWild/data/instinwild.json' \
--strategy naive \
--model bloom \
--pretrain '/mnt/f/kangpengtao/study/ColossalAI/bigscience/bloom-560m' \
--rm_path '/mnt/f/kangpengtao/study/ColossalAI/rm-static.pt' \
--rm_pretrain '/mnt/f/kangpengtao/study/ColossalAI/Coati-7B' \
--save_path '/mnt/f/kangpengtao/study/ColossalAI/prompts-static.pt' \
--max_epochs 1 \
--num_episodes 1 \
--train_batch_size 1 \
--ptx_batch_size 1 \
--experience_batch_size 1 \
--lora_rank 16
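train_prompts.py runs PPO: experiences are collected in batches of --experience_batch_size and the policy is updated with a clipped surrogate objective. A generic sketch of that objective (eps = 0.2 is an assumed default here, not taken from the script):

```python
def ppo_clip_loss(ratio: float, advantage: float, eps: float = 0.2) -> float:
    """Clipped PPO surrogate: L = -min(ratio * A, clip(ratio, 1-eps, 1+eps) * A),
    where ratio = pi_new(a|s) / pi_old(a|s)."""
    clipped = max(1.0 - eps, min(1.0 + eps, ratio))
    return -min(ratio * advantage, clipped * advantage)

# With positive advantage, the clip caps the incentive once the new policy
# over-weights an action beyond ratio 1 + eps:
print(ppo_clip_loss(1.5, 1.0))  # → -1.2 (clipped at 1.2)
print(ppo_clip_loss(0.9, 1.0))  # → -0.9 (inside the clip range, unclipped)
```

The clipping keeps each update close to the policy that generated the experience batch, which is what makes it safe to reuse experiences for several optimizer steps per episode.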
3.3.2 Log
root@LAPTOP-3SUHS40U:/home/work/ColossalAI/applications/Chat/examples# ./train_prompts_bloom_kpt.sh
GPU Memory Usage:
0 580 MiB
Now CUDA_VISIBLE_DEVICES is set to:
CUDA_VISIBLE_DEVICES=0
Some weights of the model checkpoint at /mnt/f/kangpengtao/study/ColossalAI/Coati-7B were not used when initializing BloomModel: ['transformer.h.5.self_attention.query_key_value.lora_A', 'transformer.h.7.self_attention.query_key_value.lora_A', 'transformer.h.15.self_attention.query_key_value.lora_B', 'transformer.h.6.self_attention.query_key_value.lora_A', 'transformer.h.16.mlp.dense_h_to_4h.lora_A', 'transformer.h.2.mlp.dense_h_to_4h.lora_A', 'transformer.h.22.self_attention.query_key_value.lora_B', 'transformer.h.14.mlp.dense_4h_to_h.lora_A', 'transformer.h.12.self_attention.dense.lora_A', 'transformer.h.15.mlp.dense_h_to_4h.lora_A', 'transformer.h.3.self_attention.dense.lora_A', 'transformer.h.10.self_attention.query_key_value.lora_B', 'transformer.h.1.self_attention.dense.lora_B', 'transformer.h.2.self_attention.dense.lora_A', 'transformer.h.21.mlp.dense_h_to_4h.lora_A', 'transformer.h.12.self_attention.query_key_value.lora_A', 'transformer.h.9.self_attention.query_key_value.lora_B', 'transformer.h.13.mlp.dense_4h_to_h.lora_B', 'transformer.h.3.self_attention.query_key_value.lora_A', 'transformer.h.11.mlp.dense_4h_to_h.lora_B', 'transformer.h.7.self_attention.dense.lora_A', 'transformer.h.3.self_attention.dense.lora_B', 'transformer.h.21.self_attention.dense.lora_A', 'transformer.h.7.mlp.dense_h_to_4h.lora_A', 'transformer.h.12.mlp.dense_4h_to_h.lora_B', 'transformer.h.13.self_attention.query_key_value.lora_A', 'transformer.h.19.mlp.dense_4h_to_h.lora_A', 'transformer.h.0.self_attention.dense.lora_A', 'transformer.h.13.mlp.dense_h_to_4h.lora_A', 'transformer.h.16.mlp.dense_4h_to_h.lora_A', 'transformer.h.14.mlp.dense_h_to_4h.lora_B', 'transformer.h.17.self_attention.dense.lora_A', 'transformer.h.3.mlp.dense_h_to_4h.lora_A', 'transformer.h.15.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.mlp.dense_h_to_4h.lora_A', 'transformer.h.4.mlp.dense_4h_to_h.lora_B', 'transformer.h.2.self_attention.query_key_value.lora_A', 'transformer.h.18.mlp.dense_4h_to_h.lora_A', 
'transformer.h.18.self_attention.dense.lora_A', 'transformer.h.21.mlp.dense_4h_to_h.lora_B', 'transformer.h.1.mlp.dense_h_to_4h.lora_A', 'transformer.h.18.self_attention.dense.lora_B', 'transformer.h.17.self_attention.query_key_value.lora_A', 'transformer.h.11.mlp.dense_4h_to_h.lora_A', 'transformer.h.0.self_attention.dense.lora_B', 'transformer.h.11.self_attention.query_key_value.lora_A', 'transformer.h.5.self_attention.query_key_value.lora_B', 'transformer.h.20.mlp.dense_4h_to_h.lora_A', 'transformer.h.16.mlp.dense_4h_to_h.lora_B', 'transformer.h.9.mlp.dense_h_to_4h.lora_B', 'transformer.h.17.mlp.dense_h_to_4h.lora_A', 'transformer.h.21.self_attention.dense.lora_B', 'transformer.h.18.self_attention.query_key_value.lora_B', 'transformer.h.2.self_attention.dense.lora_B', 'transformer.h.5.mlp.dense_h_to_4h.lora_B', 'transformer.h.10.mlp.dense_4h_to_h.lora_B', 'transformer.h.13.self_attention.dense.lora_A', 'transformer.h.3.mlp.dense_4h_to_h.lora_B', 'transformer.h.19.self_attention.query_key_value.lora_A', 'transformer.h.19.self_attention.dense.lora_A', 'transformer.h.13.mlp.dense_h_to_4h.lora_B', 'transformer.h.15.mlp.dense_4h_to_h.lora_B', 'transformer.h.9.self_attention.dense.lora_A', 'lm_head.lora_A', 'transformer.h.5.mlp.dense_4h_to_h.lora_A', 'transformer.h.16.self_attention.query_key_value.lora_B', 'transformer.h.7.self_attention.dense.lora_B', 'transformer.h.18.mlp.dense_h_to_4h.lora_A', 'transformer.h.11.self_attention.query_key_value.lora_B', 'transformer.h.1.mlp.dense_4h_to_h.lora_A', 'transformer.h.10.mlp.dense_4h_to_h.lora_A', 'transformer.h.21.mlp.dense_h_to_4h.lora_B', 'transformer.h.17.self_attention.query_key_value.lora_B', 'transformer.h.19.mlp.dense_h_to_4h.lora_A', 'transformer.h.6.self_attention.query_key_value.lora_B', 'transformer.h.12.self_attention.query_key_value.lora_B', 'transformer.h.9.self_attention.dense.lora_B', 'transformer.h.22.mlp.dense_4h_to_h.lora_A', 'transformer.h.19.mlp.dense_4h_to_h.lora_B', 
'transformer.h.9.mlp.dense_4h_to_h.lora_A', 'transformer.h.0.mlp.dense_4h_to_h.lora_B', 'transformer.h.6.mlp.dense_h_to_4h.lora_A', 'transformer.h.10.self_attention.query_key_value.lora_A', 'transformer.h.16.self_attention.dense.lora_A', 'transformer.h.4.mlp.dense_h_to_4h.lora_A', 'transformer.h.0.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.mlp.dense_4h_to_h.lora_A', 'transformer.h.1.mlp.dense_h_to_4h.lora_B', 'transformer.h.6.self_attention.dense.lora_B', 'transformer.h.15.self_attention.query_key_value.lora_A', 'transformer.h.0.self_attention.query_key_value.lora_A', 'transformer.h.0.mlp.dense_h_to_4h.lora_B', 'transformer.h.8.mlp.dense_4h_to_h.lora_B', 'transformer.h.10.mlp.dense_h_to_4h.lora_A', 'transformer.h.6.self_attention.dense.lora_A', 'transformer.h.11.self_attention.dense.lora_B', 'transformer.h.17.mlp.dense_4h_to_h.lora_A', 'transformer.h.23.mlp.dense_h_to_4h.lora_A', 'transformer.h.12.mlp.dense_h_to_4h.lora_B', 'transformer.h.8.mlp.dense_h_to_4h.lora_B', 'transformer.h.16.self_attention.dense.lora_B', 'transformer.h.15.self_attention.dense.lora_B', 'transformer.h.10.mlp.dense_h_to_4h.lora_B', 'transformer.h.20.mlp.dense_h_to_4h.lora_A', 'transformer.h.1.self_attention.dense.lora_A', 'transformer.h.2.mlp.dense_4h_to_h.lora_A', 'transformer.h.12.mlp.dense_4h_to_h.lora_A', 'transformer.h.3.mlp.dense_4h_to_h.lora_A', 'transformer.h.2.self_attention.query_key_value.lora_B', 'transformer.h.12.self_attention.dense.lora_B', 'transformer.h.18.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.mlp.dense_4h_to_h.lora_B', 'transformer.h.17.self_attention.dense.lora_B', 'transformer.h.16.self_attention.query_key_value.lora_A', 'transformer.h.23.self_attention.query_key_value.lora_B', 'transformer.h.7.mlp.dense_4h_to_h.lora_A', 'transformer.h.23.self_attention.dense.lora_B', 'transformer.h.14.mlp.dense_4h_to_h.lora_B', 'transformer.h.7.mlp.dense_h_to_4h.lora_B', 'transformer.h.8.self_attention.dense.lora_B', 'transformer.h.9.mlp.dense_h_to_4h.lora_A', 
'transformer.h.14.self_attention.dense.lora_B', 'transformer.h.1.mlp.dense_4h_to_h.lora_B', 'transformer.h.1.self_attention.query_key_value.lora_B', 'transformer.h.8.self_attention.query_key_value.lora_B', 'transformer.h.8.mlp.dense_h_to_4h.lora_A', 'transformer.h.2.mlp.dense_4h_to_h.lora_B', 'transformer.h.21.self_attention.query_key_value.lora_B', 'transformer.h.20.mlp.dense_4h_to_h.lora_B', 'transformer.h.0.self_attention.query_key_value.lora_B', 'transformer.h.5.self_attention.dense.lora_A', 'transformer.h.2.mlp.dense_h_to_4h.lora_B', 'transformer.h.10.self_attention.dense.lora_A', 'transformer.h.4.self_attention.query_key_value.lora_A', 'transformer.h.14.self_attention.query_key_value.lora_B', 'transformer.h.8.self_attention.query_key_value.lora_A', 'transformer.h.18.self_attention.query_key_value.lora_A', 'transformer.h.4.mlp.dense_4h_to_h.lora_A', 'transformer.h.4.self_attention.query_key_value.lora_B', 'transformer.h.4.self_attention.dense.lora_B', 'transformer.h.21.self_attention.query_key_value.lora_A', 'transformer.h.7.self_attention.query_key_value.lora_B', 'transformer.h.20.self_attention.dense.lora_A', 'transformer.h.18.mlp.dense_4h_to_h.lora_B', 'transformer.h.22.mlp.dense_h_to_4h.lora_B', 'transformer.h.22.mlp.dense_h_to_4h.lora_A', 'transformer.h.9.self_attention.query_key_value.lora_A', 'transformer.h.11.mlp.dense_h_to_4h.lora_B', 'transformer.h.5.self_attention.dense.lora_B', 'transformer.h.3.self_attention.query_key_value.lora_B', 'transformer.h.11.self_attention.dense.lora_A', 'transformer.h.23.self_attention.query_key_value.lora_A', 'transformer.h.23.mlp.dense_h_to_4h.lora_B', 'transformer.h.20.self_attention.dense.lora_B', 'transformer.h.3.mlp.dense_h_to_4h.lora_B', 'transformer.h.23.mlp.dense_4h_to_h.lora_A', 'transformer.h.13.mlp.dense_4h_to_h.lora_A', 'transformer.h.6.mlp.dense_4h_to_h.lora_A', 'transformer.h.6.mlp.dense_4h_to_h.lora_B', 'transformer.h.1.self_attention.query_key_value.lora_A', 'transformer.h.5.mlp.dense_h_to_4h.lora_A', 
'transformer.h.13.self_attention.query_key_value.lora_B', 'lm_head.lora_B', 'transformer.h.0.mlp.dense_h_to_4h.lora_A', 'transformer.h.7.mlp.dense_4h_to_h.lora_B', 'transformer.h.22.self_attention.dense.lora_B', 'transformer.h.19.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.self_attention.dense.lora_A', 'transformer.h.16.mlp.dense_h_to_4h.lora_B', 'transformer.h.6.mlp.dense_h_to_4h.lora_B', 'transformer.h.22.self_attention.query_key_value.lora_A', 'lm_head.weight', 'transformer.h.22.self_attention.dense.lora_A', 'transformer.h.15.mlp.dense_4h_to_h.lora_A', 'transformer.h.15.self_attention.dense.lora_A', 'transformer.h.10.self_attention.dense.lora_B', 'transformer.h.12.mlp.dense_h_to_4h.lora_A', 'transformer.h.20.self_attention.query_key_value.lora_B', 'transformer.h.23.self_attention.dense.lora_A', 'transformer.h.22.mlp.dense_4h_to_h.lora_B', 'transformer.h.11.mlp.dense_h_to_4h.lora_A', 'transformer.h.13.self_attention.dense.lora_B', 'transformer.h.20.self_attention.query_key_value.lora_A', 'transformer.h.5.mlp.dense_4h_to_h.lora_B', 'transformer.h.4.mlp.dense_h_to_4h.lora_B', 'transformer.h.19.self_attention.query_key_value.lora_B', 'transformer.h.20.mlp.dense_h_to_4h.lora_B', 'transformer.h.4.self_attention.dense.lora_A', 'transformer.h.19.self_attention.dense.lora_B', 'transformer.h.9.mlp.dense_4h_to_h.lora_B', 'transformer.h.17.mlp.dense_h_to_4h.lora_B', 'transformer.h.14.self_attention.query_key_value.lora_A', 'transformer.h.21.mlp.dense_4h_to_h.lora_A', 'transformer.h.8.self_attention.dense.lora_A', 'transformer.h.17.mlp.dense_4h_to_h.lora_B']
- This IS expected if you are initializing BloomModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BloomModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
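The "weights not used" warnings above list LoRA adapter tensors (`lora_A`/`lora_B`, plus `lm_head` entries) that the plain `BloomModel` has no parameters for, so `from_pretrained` skips them. Conceptually, a LoRA adapter is a low-rank update that can be folded back into the base weight as W' = W + scaling · B·A. A minimal sketch of that merge (the function name and `scaling` default are illustrative, not coati's actual API):

```python
import torch

def merge_lora(weight, lora_A, lora_B, scaling=1.0):
    """Fold a LoRA adapter into a base weight matrix: W' = W + scaling * B @ A."""
    return weight + scaling * (lora_B @ lora_A)

# toy example: a 4x4 base weight with a rank-2 adapter
W = torch.zeros(4, 4)
A = torch.ones(2, 4)   # lora_A: (rank, in_features)
B = torch.ones(4, 2)   # lora_B: (out_features, rank)
merged = merge_lora(W, A, B)
print(merged[0, 0].item())  # 2.0 (sum over the rank dimension)
```

If you want the fine-tuned behavior carried by a LoRA checkpoint, the adapters must be merged like this (or loaded by the matching LoRA-aware model class) rather than silently dropped.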
[05/21/23 00:32:36] INFO colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/prompt_dataset.py:30
__init__
INFO colossalai - colossalai - INFO: Loading data...
INFO colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/prompt_dataset.py:32
__init__
INFO colossalai - colossalai - INFO: Loaded 858 examples.
INFO colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/prompt_dataset.py:35
__init__
INFO colossalai - colossalai - INFO: Limiting dataset to 16384 examples.
INFO colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:121
__init__
INFO colossalai - colossalai - INFO: Loading data...
[05/21/23 00:32:37] INFO colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:123
__init__
INFO colossalai - colossalai - INFO: Loaded 103695 examples.
INFO colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:126
__init__
INFO colossalai - colossalai - INFO: Limiting dataset to 16384 examples.
INFO colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:129
__init__
INFO colossalai - colossalai - INFO: Formatting inputs...
INFO colossalai - colossalai - INFO: /root/miniconda3/lib/python3.10/site-packages/coati/dataset/sft_dataset.py:137
__init__
INFO colossalai - colossalai - INFO: Tokenizing inputs... This may take some time...
Train epoch [1/1]: 100%|████████████████████████████████████████████████████████████████████████████| 10/10 [00:04<00:00, 2.45it/s, reward=nan]
Episode [1/1]: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:20<00:00, 2.03s/it]
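The "Limiting dataset to 16384 examples" log lines correspond to the trainer's maximum-dataset-size option: after loading, the example list is truncated. A sketch of that behavior, assuming a plain slice (the helper below is illustrative, not the actual coati implementation):

```python
def limit_dataset(examples, max_datasets_size):
    """Truncate the loaded examples, mirroring a max-dataset-size option."""
    if max_datasets_size is not None and len(examples) > max_datasets_size:
        return examples[:max_datasets_size]
    return examples

data = list(range(103695))           # stands in for the 103695 loaded SFT examples
limited = limit_dataset(data, 16384)
print(len(limited))                  # 16384
```

Note that the prompt dataset above loaded only 858 examples, fewer than the limit, so no truncation actually happens there even though the message is printed.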
3.4 Inference
3.4.1 Code
import torch
from coati.models.bloom import BLOOMActor
from coati.models.gpt import GPTActor
from coati.models.opt import OPTActor
from coati.models.roberta import RoBERTaActor
from transformers import AutoTokenizer, RobertaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
import gradio as gr

MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2

# Replace these with the paths to your own models
model_path_dict = {
    'SFT': '/mnt/f/kangpengtao/study/ColossalAI/Coati-7B/pytorch_model.bin',
    'RM': '/mnt/f/kangpengtao/study/ColossalAI/rm-static.pt',
    'RL': '/mnt/f/kangpengtao/study/ColossalAI/prompts-static.pt',
}
pretrain_dict = {
    'bloom': '/mnt/f/kangpengtao/study/ColossalAI/bigscience/bloom-560m'
}


def predict(model, dict, input, max_length, history):
    pretrain = pretrain_dict[model]
    updates = []
    # configure model
    if model == 'gpt2':
        actor = GPTActor(pretrained=pretrain).to(torch.cuda.current_device())
    elif model == 'bloom':
        actor = BLOOMActor(pretrained=pretrain).to(torch.cuda.current_device())
    elif model == 'opt':
        actor = OPTActor(pretrained=pretrain).to(torch.cuda.current_device())
    elif model == 'roberta':
        actor = RoBERTaActor(pretrained=pretrain).to(torch.cuda.current_device())
    else:
        raise ValueError(f'Unsupported model "{model}"')
    # load the fine-tuned weights (SFT / RM / RL) on top of the pretrained backbone
    state_dict = torch.load(model_path_dict[dict])
    actor.model.load_state_dict(state_dict, strict=False)
    # configure tokenizer
    if model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif model == 'bloom':
        tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
        tokenizer.pad_token = tokenizer.eos_token
    elif model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
    elif model == 'roberta':
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    else:
        raise ValueError(f'Unsupported model "{model}"')
    actor.eval()
    question = f'Question: {input} ? Answer:'
    input_ids = tokenizer.encode(question, return_tensors='pt').to(torch.cuda.current_device())
    outputs = actor.generate(input_ids,
                             max_length=max_length,
                             do_sample=True,
                             top_k=50,
                             top_p=0.95,
                             num_return_sequences=1)
    output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
    # re-render the previous turns, then append the new question/answer pair
    for i in history:
        if not i.get('visible'):
            continue
        print(i)
        value = i.get('value')
        updates.append(gr.update(visible=True, value=value))
    updates.append(gr.update(visible=True, value="提问:" + input))
    updates.append(gr.update(visible=True, value=f"{dict}:" + output[0].replace(question, '').replace(question.replace(' ', ''), '')))
    if len(updates) < MAX_BOXES:
        updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
    history.extend(updates)
    return [history] + updates


if __name__ == '__main__':
    with gr.Blocks() as demo:
        state = gr.State([])
        text_boxes = []
        with gr.Row():
            with gr.Column(scale=1):
                model = gr.Radio(['gpt2', 'bloom', 'opt', 'roberta'], label="model", interactive=False, value='bloom')
                dict = gr.Radio(["SFT", "RM", "RL"], label="dict", interactive=True, value='RL')
                max_length = gr.Slider(0, 512, value=100, step=1.0, label="max_length", interactive=True)
                button = gr.Button("Generate")
            with gr.Column(scale=4):
                for i in range(MAX_BOXES):
                    if i % 2 == 0:
                        text_boxes += [gr.Markdown(visible=False, label="提问:")]
                    else:
                        text_boxes += [gr.Markdown(visible=False, label="回复:")]
        input = gr.Textbox(show_label=True, placeholder="input", lines=5, label='input').style(container=False)
        button.click(predict, [model, dict, input, max_length, state], [state] + text_boxes)
    demo.queue().launch(share=False, inbrowser=True, server_name='0.0.0.0')
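One detail worth noting in `predict`: the decoded generation still begins with the prompt, so the code removes the `question` prefix (and a whitespace-stripped variant, since detokenization may not reproduce the spaces exactly) before displaying only the answer. In isolation:

```python
# mirror the post-processing in predict(): strip the prompt from the decoded text
question = 'Question: 1+1 ? Answer:'
generated = question + ' 2'  # stand-in for tokenizer.batch_decode output
answer = generated.replace(question, '').replace(question.replace(' ', ''), '')
print(answer.strip())  # → 2
```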
3.4.2 Demo
提问:了解一个无家可归或社会边缘人的故事,并考虑自己的角色和关怀方式。
RL:Question: 了解一个无家可归或社会边缘人的故事,并考虑自己的角色和关怀方式。? Answer: 无论你是单身、独居或已婚,不管你和谁在一起,或者无论在哪里工作,都必须向世界发出一个邀请:去帮助他人!? Excluded Subjects: 婚姻和性关系、情感、女性、单身和女性伴侣、同性恋和异性伴侣。 作者简介 作者:(英国)戴维·巴兰森(David Baker) 译
提问:在机器学习中,什么是监督学习和无监督学习?你如何选择适当的算法?
RL:Question: 在机器学习中,什么是监督学习和无监督学习?你如何选择适当的算法?? Answer: 机器学习的深度学习方法无监督学习(Deep Learning of No-Layered):它比传统深度学习(深度学习)更具灵活性,例如,您可以通过训练网络学习某种类型(非学习任务)的模式(比如,网络训练了某条路径),然后在网络中生成新的特征,以便获得更高的模型效果(对于非学习任务,模型可能非常糟糕
提问:给定一组JSON对象,写一个程序来计算其中所有数字值的平均值。
RL:Question: 给定一组JSON对象,写一个程序来计算其中所有数字值的平均值。? Answer: 返回每一个数字的平均值。? Answer: 如果你有一个数组中的每一列都是一个数字,那么可以使用这种方法计算每一个数字的平均值。? Answer: 默认情况下,JSON会返回一个返回一个数组。? Question: 你需要一个JSON对象来存储一组数据(一个字符串和一个集合)。例如,如果你给一个字符串定义了8个属性,则应该使用这个对象来存储这些属性(这些属性被抽象到一个类下面)。? Answer: 如果该字符串的每个值都一个数字,那么JSON会返回一个数组,因为数字比较准确。但是如果你给一个字符串定义了12个属性,那么JSON会返回一个返回一个数组。这个问题就在这里。? Question: 你需要一个JSON对象来存储一组数据(一个字符串和一个集合)。例如,如果你给一个字符串定义了8个属性,那么应该使用这个对象来存储这些属性(这些属性被抽象到一个类下面)。? Answer: 如果该字符串的每个值都一个数字,那么JSON会返回一个数组,因为数字比较准确。但是如果你给一个字符串定义了12个属性,那么JSON会返回一个返回一个数组。这个问题就在这里。? Question: 你需要一个JSON对象来存储一组数据(一个字符串和一个集合)。例如,如果你给一个字符串定义了8个属性,那么应该使用这个对象来存储这些属性(这些属性被抽象到一个类下面)。? Answer: 如果该字符串的每个值都一个数字,那么JSON会返回一个数组,因为数字比较准确。但是如果你给一个字符串定义了12个属性,那么JSON会返回一个返回一个数组。这个问题就在这里。? Question: 你需要一个JSON对象来存储一组数据(一个字符串和一个集合)。例如,如果你给一个字符串定义了8个属性,那么应该使用这个对象来存储这些属性(这些属性被抽象到一个类下面)。? Answer: 如果该字符串的每个值都一个数字,那么JSON会返回一个数组,因为数字比较准确。但是如果你给一个字符串定义了12个属性,那么JSON会返回一个返回一个数组。这个问题就在这里。? Question: 你需要一个JSON对象来存储一组数据(一个字符串和一个集合)。例如,如果你给一个字符串定义了8个属性,那么应该使用这个对象来存储这些属性(这些属性被抽象到一个类下面)。? Answer
4 References
https://blog.csdn.net/chen_hao_181/article/details/130172096
Note: Markdown-to-WeChat-article conversion tool: https://md.guozh.net/