UniCRS Experiment Notes

Project: https://github.com/RUCAIBox/UniCRS

Hardware: Linux server with an RTX 3090 (24 GB)

Dataset: redial (the inspired dataset is run the same way, so it is not covered again here).

Anaconda dependencies: install the packages that tend to fail with the commands below; the full dependency list is at the end of this post.

pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio===0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install torch-cluster==1.5.9 torch-geometric==2.0.1 torch-scatter==2.0.8 torch-sparse==0.6.12 torch-spline-conv==1.2.1 -f https://data.pyg.org/whl/torch-1.8.1+cu111.html
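
A quick sanity check that the CUDA build installed correctly before going further:

import torch

print(torch.__version__)              # expect 1.8.1+cu111
print(torch.cuda.is_available())      # expect True
print(torch.cuda.get_device_name(0))  # expect the RTX 3090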

1. Download the DBpedia data

After downloading the DBpedia dump, extract it into the project's data/dbpedia directory.
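
A minimal download sketch; both the URL and the file name below are assumptions, so check them against what data/dbpedia/extract_kg.py actually reads:

import bz2
import os
import shutil
import urllib.request

# Hypothetical release URL -- verify the version and file name first.
URL = ("https://downloads.dbpedia.org/repo/dbpedia/mappings/"
       "mappingbased-objects/2021.09.01/mappingbased-objects_lang=en.ttl.bz2")
archive = "data/dbpedia/mappingbased-objects_lang=en.ttl.bz2"

os.makedirs("data/dbpedia", exist_ok=True)
urllib.request.urlretrieve(URL, archive)  # download the compressed dump

# Stream-decompress the .bz2 archive next to it.
with bz2.open(archive, "rb") as src, open(archive[:-4], "wb") as dst:
    shutil.copyfileobj(src, dst)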

2. Data preprocessing

cd data
python dbpedia/extract_kg.py

# redial
python redial/extract_subkg.py
python redial/remove_entity.py

# inspired
python inspired/extract_subkg.py
python inspired/remove_entity.py

3. Prompt Pre-training

This step requires models hosted on Hugging Face. Since the official site can be hard to reach, the hf-mirror.com mirror is used here.

pip install -U huggingface_hub
export HF_ENDPOINT=https://hf-mirror.com
huggingface-cli download --resume-download microsoft/DialoGPT-small --local-dir ./pretrain_model/DialoGPT-small  # your own path
huggingface-cli download --resume-download roberta-base --local-dir ./pretrain_model/roberta-base  # your own path
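
The same downloads can also be scripted with huggingface_hub (the environment variable must be set before the import so the mirror endpoint is picked up):

import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"  # route through the mirror

from huggingface_hub import snapshot_download

snapshot_download("microsoft/DialoGPT-small", local_dir="./pretrain_model/DialoGPT-small")
snapshot_download("roberta-base", local_dir="./pretrain_model/roberta-base")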

After the models are downloaded, don't forget to run accelerate config once to set up the basic environment; then training can start.
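
A small sanity check that the Accelerate config took effect (generic, not UniCRS-specific):

from accelerate import Accelerator

accelerator = Accelerator()
print(accelerator.device)            # e.g. cuda
print(accelerator.distributed_type)  # DistributedType.NO for a single GPU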

cp -r data/redial src/data/
cd src
python data/redial/process.py
accelerate launch --main_process_port 29501 train_pre.py \
    --dataset redial \  # [redial, inspired]
    --tokenizer ./pretrain_model/DialoGPT-small \  # the model just downloaded
    --model ./pretrain_model/DialoGPT-small \      # the model just downloaded
    --text_tokenizer ./pretrain_model/roberta-base \  # the model just downloaded
    --text_encoder ./pretrain_model/roberta-base \    # the model just downloaded
    --num_train_epochs 5 \
    --gradient_accumulation_steps 1 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 128 \
    --num_warmup_steps 1389 \  # 168 for inspired
    --max_length 200 \
    --prompt_max_length 200 \
    --entity_max_length 32 \
    --learning_rate 5e-4 \  # 6e-4 for inspired
    --output_dir /path/to/pre-trained \  # set your own save path for the pre-trained prompt
    --use_wandb \  # if you do not want to use wandb, comment it and the lines below
    --project crs-prompt-pre \  # wandb project name
    --name xxx  # wandb experiment name

--main_process_port 29501: use this flag when the default port 29500 is already occupied.
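
To check whether 29500 is free before launching, a generic socket probe (not part of the UniCRS codebase):

import socket

def port_in_use(port: int) -> bool:
    """Return True if something is already listening on localhost:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("127.0.0.1", port)) == 0

print(port_in_use(29500))  # True -> pass a different --main_process_port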

4. Conversation Task Training and Inference

4.1 Conversation Task

The authors' released code has a bug here. Specifically, in utils.py, when mixed-precision (AMP) training is used, the padded item dimension is adjusted to a multiple of 8. The original code rounds t down to a multiple of 8, so the adjusted t can end up smaller than the longest item in the batch; the fix is to round t up to a multiple of 8 instead.

# Error message
Traceback (most recent call last):
  File "train_conv.py", line 267, in <module>
    for step, batch in enumerate(train_dataloader):
  File "/home/anaconda3/envs/UniCRS/lib/python3.8/site-packages/accelerate/data_loader.py", line 303, in __iter__
    for batch in super().__iter__():
  File "/home/anaconda3/envs/UniCRS/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
    data = self._next_data()
  File "/home/anaconda3/envs/UniCRS/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 557, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/anaconda3/envs/UniCRS/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/code/UniCRS/UniCRS-main/src/dataset_conv.py", line 200, in __call__
    entity_batch = padded_tensor(
  File "/home/code/UniCRS/UniCRS-main/src/utils.py", line 55, in padded_tensor
    output[i, :length] = item
RuntimeError: The expanded size of the tensor (16) must match the existing size (21) at non-singleton dimension 0. Target sizes: [16]. Tensor sizes: [21]
# Before
if use_amp:
    t = t // 8 * 8  # rounds down, so t can become smaller than the longest item
# After (requires `import math` at the top of utils.py)
if use_amp:
    t = math.ceil(t / 8) * 8  # round up to the next multiple of 8
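
To see why rounding down breaks padding, plug in the sizes from the traceback above (the longest entity list has 21 entries):

import math

t = 21                       # length of the longest item in the batch
print(t // 8 * 8)            # 16 -> smaller than 21, so the copy overflows
print(math.ceil(t / 8) * 8)  # 24 -> large enough and still a multiple of 8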

After fixing the code, start training.

cp -r data/redial src/data/
cd src
python data/redial/process_mask.py
accelerate launch train_conv.py \
    --dataset redial \  # [redial, inspired]
    --tokenizer ./pretrain_model/DialoGPT-small \
    --model ./pretrain_model/DialoGPT-small \
    --text_tokenizer ./pretrain_model/roberta-base \
    --text_encoder ./pretrain_model/roberta-base \
    --n_prefix_conv 20 \  
    --prompt_encoder  /path/to/pre-trained/final \  # set to your save path of the pre-trained prompt
    --num_train_epochs 10 \
    --gradient_accumulation_steps 1 \
    --ignore_pad_token_for_loss \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 16 \
    --num_warmup_steps 6345 \  # 976 for inspired
    --context_max_length 200 \
    --resp_max_length 183 \
    --prompt_max_length 200 \
    --entity_max_length 32 \
    --learning_rate 1e-4 \
    --output_dir ./model/prompt/dialogpt_prompt-pre_prefix-20_redial_1e-4 \  # set your own save path
    --use_wandb \  # if you do not want to use wandb, comment it and the lines below
    --project crs-prompt-conv \  # wandb project name
    --name xxx  # wandb experiment name

4.2 Inference Task

# Generate data for the training split
accelerate launch --main_process_port 29501 infer_conv.py --dataset redial --split train --tokenizer ./pretrain_model/DialoGPT-small --model ./pretrain_model/DialoGPT-small --text_tokenizer ./pretrain_model/roberta-base --text_encoder ./pretrain_model/roberta-base --n_prefix_conv 20 --prompt_encoder  ./model/prompt/dialogpt_prompt-pre_prefix-20_redial_1e-4/final --per_device_eval_batch_size 64 --context_max_length 200 --resp_max_length 183 --prompt_max_length 200 --entity_max_length 32
# Generate data for the validation split
accelerate launch --main_process_port 29501 infer_conv.py --dataset redial --split valid --tokenizer ./pretrain_model/DialoGPT-small --model ./pretrain_model/DialoGPT-small --text_tokenizer ./pretrain_model/roberta-base --text_encoder ./pretrain_model/roberta-base --n_prefix_conv 20 --prompt_encoder  ./model/prompt/dialogpt_prompt-pre_prefix-20_redial_1e-4/final --per_device_eval_batch_size 64 --context_max_length 200 --resp_max_length 183 --prompt_max_length 200 --entity_max_length 32
# Generate data for the test split
accelerate launch --main_process_port 29501 infer_conv.py --dataset redial --split test --tokenizer ./pretrain_model/DialoGPT-small --model ./pretrain_model/DialoGPT-small --text_tokenizer ./pretrain_model/roberta-base --text_encoder ./pretrain_model/roberta-base --n_prefix_conv 20 --prompt_encoder  ./model/prompt/dialogpt_prompt-pre_prefix-20_redial_1e-4/final --per_device_eval_batch_size 64 --context_max_length 200 --resp_max_length 183 --prompt_max_length 200 --entity_max_length 32

5. Recommendation Task

First, edit src/data/redial_gen/merge.py and set dataset = 'redial':

parser = ArgumentParser()
parser.add_argument("--gen_file_prefix", type=str, required=True)
args = parser.parse_args()
gen_file_prefix = args.gen_file_prefix
dataset = 'redial'  # changed from the original 'inspired'

Merge the inference results from the conversation task:

cd src
cp -r data/redial/. data/redial_gen/
python data/redial_gen/merge.py --gen_file_prefix dialogpt_prompt-pre_prefix-20_redial_1e-4

Train the recommendation module. Note that the recommendation task is trained with train_rec.py on the merged redial_gen data, not with train_conv.py; the flags below follow the UniCRS README for ReDial (adjust the prompt encoder and output paths to your own):

accelerate launch train_rec.py --dataset redial_gen --tokenizer ./pretrain_model/DialoGPT-small --model ./pretrain_model/DialoGPT-small --text_tokenizer ./pretrain_model/roberta-base --text_encoder ./pretrain_model/roberta-base --n_prefix_rec 10 --prompt_encoder ./model/prompt_encoder/final --num_train_epochs 5 --gradient_accumulation_steps 1 --per_device_train_batch_size 64 --per_device_eval_batch_size 64 --num_warmup_steps 530 --context_max_length 200 --prompt_max_length 200 --entity_max_length 32 --learning_rate 1e-4 --output_dir ./model/rec/dialogpt_rec_prefix-10_redial_1e-4 --use_wandb --project UniCRS --name dialogpt_rec_prefix-10_redial_1e-4 --log_all

Experimental results:


wandb: Run summary:
wandb:           epoch 4
wandb:            loss 4.7062
wandb:       test/loss 7.56063
wandb:      test/mrr@1 0.04557
wandb:     test/mrr@10 0.08932
wandb:     test/mrr@50 0.09801
wandb:     test/ndcg@1 0.04557
wandb:    test/ndcg@10 0.11898
wandb:    test/ndcg@50 0.16031
wandb:   test/recall@1 0.04557
wandb:  test/recall@10 0.21646
wandb:  test/recall@50 0.40506
wandb:      valid/loss 7.05796
wandb:     valid/mrr@1 0.04843
wandb:    valid/mrr@10 0.09765
wandb:    valid/mrr@50 0.10872
wandb:    valid/ndcg@1 0.04843
wandb:   valid/ndcg@10 0.13076
wandb:   valid/ndcg@50 0.18214
wandb:  valid/recall@1 0.04843
wandb: valid/recall@10 0.23945
wandb: valid/recall@50 0.47132
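
For reference, recall@k, mrr@k, and ndcg@k above are standard top-k ranking metrics, averaged over test examples. A minimal sketch for a single ground-truth item (illustrative only, not the UniCRS evaluation code):

import math

def rank_metrics(ranked_items, target, k):
    """recall@k, mrr@k, ndcg@k when there is a single ground-truth item."""
    if target in ranked_items[:k]:
        rank = ranked_items.index(target) + 1  # 1-based position of the hit
        return 1.0, 1.0 / rank, 1.0 / math.log2(rank + 1)
    return 0.0, 0.0, 0.0

# Example: the target movie is ranked 3rd among the predictions.
print(rank_metrics(["m12", "m7", "m42", "m5"], "m42", k=10))
# -> (1.0, 0.333..., 0.5)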

Requirements

accelerate==0.8.0
aiohappyeyeballs==2.3.4
aiohttp==3.10.1
aiosignal==1.3.1
async-timeout==4.0.3
attrs==24.1.0
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
docker-pycreds==0.4.0
filelock==3.15.4
frozenlist==1.4.1
fsspec==2024.6.1
gitdb==4.0.11
GitPython==3.1.43
googledrivedownloader==0.4
huggingface-hub==0.24.5
idna==3.7
isodate==0.6.1
Jinja2==3.1.4
joblib==1.4.2
loguru==0.7.2
MarkupSafe==2.1.5
multidict==6.0.5
networkx==3.1
nltk==3.8.1
numpy==1.24.4
packaging==24.1
pandas==2.0.3
pillow==10.4.0
platformdirs==4.2.2
protobuf==5.27.3
psutil==6.0.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
rdflib==7.0.0
regex==2024.7.24
requests==2.32.3
sacremoses==0.1.1
scikit-learn==1.3.2
scipy==1.10.1
sentry-sdk==2.12.0
setproctitle==1.3.3
six==1.16.0
smmap==5.0.1
threadpoolctl==3.5.0
tokenizers==0.10.3
torch==1.8.1+cu111
torch-cluster==1.5.9
torch-geometric==2.0.1
torch-scatter==2.0.8
torch-sparse==0.6.12
torch-spline-conv==1.2.1
torchaudio==0.8.1
torchvision==0.9.1+cu111
tqdm==4.66.5
transformers==4.15.0
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.2
wandb==0.17.5
yacs==0.1.8
yarl==1.9.4