Fine-tuning RoBERTa to reproduce the GLUE results

一. Reference blogs and papers

Finetuning RoBERTa on GLUE tasks

二. Preprocess GLUE task data

2.1 Download the GLUE datasets

Download link for the GLUE datasets: GLUE

import os
import sys
import shutil
import argparse
import tempfile
import urllib.request
import zipfile

TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
             "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
             "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
             "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
             "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
             "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
             "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
             "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601',
             "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
             "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
             "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}

MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'

def download_and_extract(task, data_dir):
    print("Downloading and extracting %s..." % task)
    data_file = "%s.zip" % task
    urllib.request.urlretrieve(TASK2PATH[task], data_file)
    with zipfile.ZipFile(data_file) as zip_ref:
        zip_ref.extractall(data_dir)
    os.remove(data_file)
    print("\tCompleted!")

def format_mrpc(data_dir, path_to_data):
    print("Processing MRPC...")
    mrpc_dir = os.path.join(data_dir, "MRPC")
    if not os.path.isdir(mrpc_dir):
        os.mkdir(mrpc_dir)
    if path_to_data:
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    else:
        print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN)
        mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
        urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
        urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
    assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
    assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
    urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))

    dev_ids = []
    with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
        for row in ids_fh:
            dev_ids.append(row.strip().split('\t'))

    with open(mrpc_train_file, encoding="utf8") as data_fh, \
         open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \
         open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh:
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            label, id1, id2, s1, s2 = row.strip().split('\t')
            if [id1, id2] in dev_ids:
                dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
            else:
                train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))

    with open(mrpc_test_file, encoding="utf8") as data_fh, \
            open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh:
        header = data_fh.readline()
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        for idx, row in enumerate(data_fh):
            label, id1, id2, s1, s2 = row.strip().split('\t')
            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
    print("\tCompleted!")

def download_diagnostic(data_dir):
    print("Downloading and extracting diagnostic...")
    if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
        os.mkdir(os.path.join(data_dir, "diagnostic"))
    data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
    urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file)
    print("\tCompleted!")
    return

def get_tasks(task_names):
    task_names = task_names.split(',')
    if "all" in task_names:
        tasks = TASKS
    else:
        tasks = []
        for task_name in task_names:
            assert task_name in TASKS, "Task %s not found!" % task_name
            tasks.append(task_name)
    return tasks

def main(arguments):
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
    parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
                        type=str, default='all')
    parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_test.txt',
                        type=str, default='')
    args = parser.parse_args(arguments)

    if not os.path.isdir(args.data_dir):
        os.mkdir(args.data_dir)
    tasks = get_tasks(args.tasks)

    for task in tasks:
        if task == 'MRPC':
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == 'diagnostic':
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)

if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))

This script downloads each GLUE dataset from its URL and saves it into the data folder.
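For example, assuming the script above is saved as download_glue_data.py (the filename here is my choice for illustration, not fixed by the source), it can be run as:

# download all GLUE tasks into ./glue_data
python download_glue_data.py --data_dir glue_data --tasks all
# or only a subset, e.g. RTE and CoLA
python download_glue_data.py --data_dir glue_data --tasks RTE,CoLA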

2.2 Preprocess the GLUE datasets

if [[ $# -ne 2 ]]; then
  echo "Run as following:"
  echo "./examples/roberta/preprocess_GLUE_tasks.sh <glud_data_folder> <task_name>"
  exit 1
fi
  • ‘$#’: a special variable holding the number of command-line arguments.
  • ‘-ne’: a comparison operator meaning "not equal".
  • ‘echo "Run as following:"’: echo is a command that prints text to the terminal, so these lines print the usage message.
  • ‘exit 1’: exit is a builtin command that terminates the current script or shell session; 1 is the exit status code, and a non-zero status indicates an abnormal (error) exit.
TASKS=$2 # QQP
  • ‘TASKS=’: variable-assignment syntax; TASKS is the variable name.
  • ‘$2’: a positional parameter; in a Bash script, ‘$2’ refers to the second command-line argument.
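As a minimal toy script (not part of the preprocessing code) illustrating these constructs:

#!/bin/bash
# args.sh: demonstrate $#, -ne, exit and positional parameters
if [[ $# -ne 2 ]]; then        # $# is the argument count; -ne means "not equal"
    echo "expected 2 arguments, got $#"
    exit 1                     # non-zero exit status signals an error to the caller
fi
echo "first: $1, second: $2"   # $1/$2 are the first/second command-line arguments
# $ bash args.sh glue_data QQP   ->   first: glue_data, second: QQP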

  SPLITS="train dev test"
  INPUT_COUNT=2
  if [ "$TASK" = "QQP" ]     
  then
    INPUT_COLUMNS=( 4 5 )     # train.tsv: id, qid1, qid2, question1, question2, is_duplicate
    TEST_INPUT_COLUMNS=( 2 3 )     # test.tsv: id, question1, question2
    LABEL_COLUMN=6
  elif [ "$TASK" = "MNLI" ]
  then
    SPLITS="train dev_matched dev_mismatched test_matched test_mismatched"
    INPUT_COLUMNS=( 9 10 )   # train.tsv: index, promptID, pairID, genre, sentence1_binary_parse, sentence2_binary_parse, sentence1_parse, sentence2_parse, sentence1, sentence2, label1, gold_label
    TEST_INPUT_COLUMNS=( 9 10 )
    DEV_LABEL_COLUMN=16      # dev.tsv: index, promptID, pairID, genre, sentence1_binary_parse, sentence2_binary_parse, sentence1_parse, sentence2_parse, sentence1, sentence2, label1, label2, label3, label4, label5, gold_label
    LABEL_COLUMN=12
  elif [ "$TASK" = "QNLI" ]
  then
    INPUT_COLUMNS=( 2 3 )     # train.tsv: index, question, sentence, label
    TEST_INPUT_COLUMNS=( 2 3 )     
    LABEL_COLUMN=4
  elif [ "$TASK" = "MRPC" ]
  then
    INPUT_COLUMNS=( 4 5 )     # train.txt: Quality, #1 ID, #2 ID, #1 String, #2 String
    TEST_INPUT_COLUMNS=( 4 5 )
    LABEL_COLUMN=1
  elif [ "$TASK" = "RTE" ]
  then
    INPUT_COLUMNS=( 2 3 )     # train.tsv: index, sentence1, sentence2, label
    TEST_INPUT_COLUMNS=( 2 3 )
    LABEL_COLUMN=4
  elif [ "$TASK" = "STS-B" ]
  then
    INPUT_COLUMNS=( 8 9 )     # train.tsv: index, genre, filename, year, old_index, source1, source2, sentence1, sentence2, score
    TEST_INPUT_COLUMNS=( 8 9 )
    LABEL_COLUMN=10
  # Following are single sentence tasks.
  elif [ "$TASK" = "SST-2" ]
  then
    INPUT_COLUMNS=( 1 )     # train.tsv: sentence, label
    TEST_INPUT_COLUMNS=( 2 )     # test.tsv: index, sentence
    LABEL_COLUMN=2
    INPUT_COUNT=1
  elif [ "$TASK" = "CoLA" ]
  then
    INPUT_COLUMNS=( 4 )     # train.tsv: gj04, 1, *, 'The gardener watered the flowers.'
    TEST_INPUT_COLUMNS=( 2 )     # test.tsv: index, sentence
    LABEL_COLUMN=2
    INPUT_COUNT=1
  fi
  • ‘INPUT_COLUMNS’: the column indices of the input features in the train set. For example, a sentence-pair task has two feature columns, one per sentence.
  • ‘TEST_INPUT_COLUMNS’: the column indices of the input features in the test set.
  • ‘LABEL_COLUMN’: the column index of the label/score in the train set.
  • ‘DEV_LABEL_COLUMN’: the column index of the label in the dev set (only needed for MNLI, whose dev label column differs from train).
  • ‘INPUT_COUNT’: the number of feature columns in the train set.
  rm -rf "$TASK_DATA_FOLDER/processed"
  mkdir -p "$TASK_DATA_FOLDER/processed"
  • Delete any existing processed file/folder under this task's data directory, then recreate the processed folder, which will hold the processed data.
  for SPLIT in $SPLITS     # SPLITS: 'train, dev, test' or 'train, dev_matched, dev_mismatched, test_matched, test_mismatched'
  do
    # CoLA train and dev doesn't have header.
    if [[ ( "$TASK" = "CoLA") && ( "$SPLIT" != "test" ) ]]
    then
      cp "$TASK_DATA_FOLDER/$SPLIT.tsv" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
    else
      tail -n +2 "$TASK_DATA_FOLDER/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
    fi

    # Remove unformatted lines from train and dev files for QQP dataset.
    if [[ ( "$TASK" = "QQP") && ( "$SPLIT" != "test" ) ]]
    then
      awk -F '\t' -v NUM_FIELDS=6 'NF==NUM_FIELDS{print}{}' "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
    else
      cp "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
    fi
    rm "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
  done
  • If the task is CoLA and the split is train or dev, the file is copied directly to processed/xxx.tsv.temp (CoLA's train and dev have no header); otherwise, tail -n +2 reads from the second line onward, writing the header-stripped file to processed/xxx.tsv.temp.
  • A well-formed QQP row has exactly 6 tab-separated columns; the awk command filters out the malformed rows.
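A toy example of both commands (the file and its contents are made up for illustration):

# toy.tsv: a header row, one well-formed row, one malformed row
printf 'id\tq1\tq2\nrow1\ta\tb\nbad_row\n' > toy.tsv
tail -n +2 toy.tsv                         # strips the header; prints "row1 ..." and "bad_row"
tail -n +2 toy.tsv | awk -F '\t' 'NF==3'   # additionally drops rows without exactly 3 fields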
  # Split into input0, input1 and label
  for SPLIT in $SPLITS
  do
    for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
    do
      if [[ "$SPLIT" != test* ]]
      then
        COLUMN_NUMBER=${INPUT_COLUMNS[$INPUT_TYPE]}
      else
        COLUMN_NUMBER=${TEST_INPUT_COLUMNS[$INPUT_TYPE]}
      fi
      cut -f"$COLUMN_NUMBER" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.raw.input$INPUT_TYPE";
    done

    if [[ "$SPLIT" != test* ]]
    then
      if [ "$TASK" = "MNLI" ] && [ "$SPLIT" != "train" ]
      then
        cut -f"$DEV_LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv"  > "$TASK_DATA_FOLDER/processed/$SPLIT.label";
      else
        cut -f"$LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label";
      fi
    fi

This block extracts the inputs and the labels, using the shell cut command, which copies the specified column of the source file into the output file.
All files end up under paths such as xxx/processed/train.label and xxx/processed/test.raw.input0.
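For instance, for a QNLI-style file (question in column 2, sentence in column 3, label in column 4), the extraction amounts to (paths shortened for readability):

cut -f2 processed/train.tsv > processed/train.raw.input0   # question column
cut -f3 processed/train.tsv > processed/train.raw.input1   # sentence column
cut -f4 processed/train.tsv > processed/train.label        # label column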

    # BPE encode.
    for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
    do
      LANG="input$INPUT_TYPE"
      echo "BPE encoding $SPLIT/$LANG"
      python -m examples.roberta.multiprocessing_bpe_encoder \
      --encoder-json encoder.json \
      --vocab-bpe vocab.bpe \
      --inputs "$TASK_DATA_FOLDER/processed/$SPLIT.raw.$LANG" \
      --outputs "$TASK_DATA_FOLDER/processed/$SPLIT.$LANG" \
      --workers 60 \
      --keep-empty;
    done

Encode the extracted inputs with BPE.
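The encoder.json, vocab.bpe, and dict.txt files used here are the GPT-2 BPE resources; they can be fetched with the wget commands that also appear (commented out) in the full script in section 2.2.2:

wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'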

  # Remove output directory.
  rm -rf "$TASK-bin"

  DEVPREF="$TASK_DATA_FOLDER/processed/dev.LANG"
  TESTPREF="$TASK_DATA_FOLDER/processed/test.LANG"
  if [ "$TASK" = "MNLI" ]
  then
    DEVPREF="$TASK_DATA_FOLDER/processed/dev_matched.LANG,$TASK_DATA_FOLDER/processed/dev_mismatched.LANG"
    TESTPREF="$TASK_DATA_FOLDER/processed/test_matched.LANG,$TASK_DATA_FOLDER/processed/test_mismatched.LANG"
  fi

  # Run fairseq preprocessing:
  for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
  do
    LANG="input$INPUT_TYPE"
    fairseq-preprocess \
      --only-source \
      --trainpref "$TASK_DATA_FOLDER/processed/train.$LANG" \
      --validpref "${DEVPREF//LANG/$LANG}" \
      --testpref "${TESTPREF//LANG/$LANG}" \
      --destdir "$TASK-bin/$LANG" \
      --workers 60 \
      --srcdict dict.txt;
  done
  if [[ "$TASK" !=  "STS-B" ]]
  then
    fairseq-preprocess \
      --only-source \
      --trainpref "$TASK_DATA_FOLDER/processed/train.label" \
      --validpref "${DEVPREF//LANG/label}" \
      --destdir "$TASK-bin/label" \
      --workers 60;
  else
    # For STS-B output range is converted to be between: [0.0, 1.0]
    mkdir -p "$TASK-bin/label"
    awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/train.label" > "$TASK-bin/label/train.label"
    awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/dev.label" > "$TASK-bin/label/valid.label"
  fi

Finally, the fairseq-preprocess command binarizes the datasets.
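After this step each task has a binarized output directory. Roughly (a sketch of the expected layout; exact file lists may vary):

RTE-bin/
├── input0/    # train.bin/.idx, valid.bin/.idx, test.bin/.idx, dict.txt, preprocess.log
├── input1/    # same layout as input0
└── label/     # binarized train/valid labels (for STS-B, plain-text score files instead)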

2.2.1 Algorithm outline, overall code, and run results

  • Overall idea: turn the nine GLUE task files into binarized, trainable files.
  • Required files:
    • The nine GLUE datasets: CoLA, MNLI, MRPC, QNLI, QQP, RTE, SST-2, STS-B, WNLI. (This first run only processes the first 8.)
    • The preprocess_GLUE_tasks.sh script.
    • The BPE vocabulary and the BPE encoding entry point, i.e. the following files: examples.roberta.multiprocessing_bpe_encoder (multiprocessing_bpe_encoder.py), encoder.json, vocab.bpe.
  • Detailed steps:
      1. Determine which columns hold the input_features and the label for each downstream task.
      2. Strip each task's header; CoLA needs special handling because its files have no header to begin with.
      3. Remove the unformatted lines from the QQP dataset.
      4. Using the columns from step 1 and the cleaned data from steps 2-3, extract the features and the label directly.
      5. BPE-encode the features.
      6. Run fairseq-preprocess to build the train, dev, and test sets, producing the bin, log, and idx files needed for later training.
  • A concrete walk-through (all tasks; a single-call version is shown after this list):
      1. Set the file paths, taking all tasks as the example: ~/LLM/GLUE/MyProcess_Glue/preprocess_GLUE_tasks.sh data ALL.
      2. Strip every file's header; taking CoLA as the example, data/CoLA/train.tsv -> data/CoLA/processed/train_temp.tsv.
      3. Drop the unformatted lines from the QQP dataset, store it (and every other dataset) in data/CoLA/processed/train.tsv, then delete data/CoLA/processed/train_temp.tsv.
      4. Extract the feature and label columns. Note that MNLI's label column differs between dev and train and must be handled separately; all other tasks are uniform. Features go to data/CoLA/processed/train_raw_input0.tsv and train_raw_input1.tsv; labels go to data/CoLA/processed/train_label.tsv.
      5. BPE-encode, writing the encoded output to e.g. CoLA/processed/train_input0.tsv.
      6. Run fairseq-preprocess with the encoded files as input, emitting the corresponding bin, log, and idx files into the CoLA-bin folder.
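With everything in place, the whole pipeline for all eight tasks is a single call (using the paths from step 1 above):

bash ~/LLM/GLUE/MyProcess_Glue/preprocess_GLUE_tasks.sh data ALL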

2.2.2 Full code and processing results

Full code

#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This is my program to address glue dataset.

# judge the line
if [[ $# -ne 2 ]]; then
    echo "Run as following:"
    echo "~/LLM/GLUE/MyProcess_Glue/preprocess_GLUE_tasks.sh <glue_data_folder> <tssk_name>"
    exit 1
fi

# get the path of folder
GLUE_DATA_FOLDER=$1
# get the tasks of operating
TASKS=$2

# download bpe encoder.json, vocabulary and fairseq dictionary
# wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
# wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
# wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'

if [ "$TASKS" = "ALL" ]
then
    TASKS="QQP MNLI QNLI MRPC RTE STS-B SST-2 CoLA"
fi

# the starting of preprocessing tasks
for TASK in $TASKS
do
    echo "Precessing $TASK"

    # get current task's directory
    TASK_DATA_FOLDER="$GLUE_DATA_FOLDER/$TASK"
    echo "Raw data as download from glue directory: $TASK_DATA_FOLDER"
    
    # We will get three dataset
    SPLITS="train dev test"
    INPUT_COUNT=2 # default number of input sentences
    # get the columns of features and label, in respect to train, dev, test.
    if [ "$TASK" = "QQP" ] # train,dev,test
    then
        INPUT_COLUMNS=( 4 5 ) # id,qid1,qid2,question1,question2,is_duplicate
        TEST_INPUT_COLUMNS=( 2 3 ) # id,question1,question2
        LABEL_COLUMN=6
    elif [ "$TASK" = "MNLI" ] # train,test_match,test_mismatch,dev_m,dev_m
    then
        SPLITS="train dev_matched dev_mismatched test_matched test_mismatched"
        INPUT_COLUMNS=( 9 10 ) # index,promptID,pairID,genre,sentence1_binary_parse,sentence2_binary_parse,sentence1_parse,sentence2_parse,sentence1,sentence2,label1,gold_label
        TEST_INPUT_COLUMNS=( 9 10 ) # index,promptID,pairID,genre,sentence1_binary_parse,sentence2_binary_parse,sentence1_parse,sentence2_parse,sentence1,sentence2
        DEV_LABEL_COLUMN=16 # index,promptID,pairID,genre,sentence1_binary_parse,sentence2_binary_parse,sentence1_parse,sentence2_parse,sentence1,sentence2,label1,label2,label3,label4,label5,gold_label
        LABEL_COLUMN=12 
    elif [ "$TASK" = "QNLI" ] # train,test,dev
    then
        INPUT_COLUMNS=( 2 3 ) # index,question,sentence,label
        TEST_INPUT_COLUMNS=( 2 3 ) # index, question, sentence
        LABEL_COLUMN=4
    elif [ "$TASK" = "MRPC" ] # train(dev), test
    then
        INPUT_COLUMNS=( 4 5 ) # Quality,1ID,2ID,1String,2String
        TEST_INPUT_COLUMNS=( 4 5 ) # Quality,1ID,2ID,1String,2String
        LABEL_COLUMN=1
    elif [ "$TASK" = "RTE" ] # train,dev,test
    then
        INPUT_COLUMNS=( 2 3 ) # index,sentence1,sentence2,label
        TEST_INPUT_COLUMNS=( 2 3 ) # index,sentence1,sentence2
        LABEL_COLUMN=4
    elif [ "$TASK" = "STS-B" ] # train,dev,test
    then
        INPUT_COLUMNS=( 8 9 ) # index,genre,filename,year,old_index,source1,source2,sentence1,sentence2,score
        TEST_INPUT_COLUMNS=( 8 9 ) # index,genre,filename,year,old_index,source1,source2,sentence1,sentence2
        LABEL_COLUMN=10
    elif [ "$TASK" = "SST-2" ] # train,dev,test
    then
        INPUT_COLUMNS=( 1 ) # sentence,label
        TEST_INPUT_COLUMNS=( 2 ) # index,sentence
        LABEL_COLUMN=2
        INPUT_COUNT=1
    elif [ "$TASK" = "CoLA" ] # train(there aren't heads),dev(too),test
    then
        INPUT_COLUMNS=( 4 ) # xxx,1,*,sentence
        TEST_INPUT_COLUMNS=( 2 )
        LABEL_COLUMN=2
        INPUT_COUNT=1
    fi 

    # mkdir a folder to save our new dataset processed
    rm -rf "$TASK_DATA_FOLDER/processed"
    mkdir -p "$TASK_DATA_FOLDER/processed"
    
    # get the specified columns from $TASK_DATA_FOLDER=$GLUE_DATA_FOLDER/$TASK
    for SPLIT in $SPLITS 
    do
        # CoLA train and dev doesn't have a header.
        if [[ ( "$TASK" = "CoLA" ) && ( "$SPLIT" != "test" ) ]]
        then   # CoLA's train or dev
            cp "$TASK_DATA_FOLDER/$SPLIT.tsv" "$TASK_DATA_FOLDER/processed/${SPLIT}_temp.tsv";
        else
            tail -n +2 "$TASK_DATA_FOLDER/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/${SPLIT}_temp.tsv";
        fi

        # Remove unformatted lines from train and dev files for QQP dataset.
        if [[ ( "$TASK" = "QQP" ) && ( "$SPLIT" != "test" ) ]]
        then
            awk -F '\t' -v NUM_FIELDS=6 'NF==NUM_FIELDS{print}{}' "$TASK_DATA_FOLDER/processed/${SPLIT}_temp.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
        else
            cp "$TASK_DATA_FOLDER/processed/${SPLIT}_temp.tsv" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
        fi
        rm "$TASK_DATA_FOLDER/processed/${SPLIT}_temp.tsv";
    done
    
    # Get features and label columns, called them "input0, input1, label"
    for SPLIT in $SPLITS
    do
        # Extract features
        for INPUT_TYPE in $(seq 0 $(( INPUT_COUNT - 1 )))
        do   # process the train and dev dataset.
            if [[ "$SPLIT" != test* ]]
            then 
                COLUMN_NUMBER=${INPUT_COLUMNS[$INPUT_TYPE]}
            else
                COLUMN_NUMBER=${TEST_INPUT_COLUMNS[$INPUT_TYPE]}
            fi
            cut -f "$COLUMN_NUMBER" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/${SPLIT}_raw_input$INPUT_TYPE.tsv";
        done

        # Extract labels
        if [[ "$SPLIT" != test* ]]
        then
            if [ "$TASK" = "MNLI" ] && [ "$SPLIT" != "train" ]  # Only this dataset's dev have a different label column, in respect to train dataset
            then
                cut -f "$DEV_LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/${SPLIT}_label.tsv";            
            else
                cut -f "$LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/${SPLIT}_label.tsv";    
            fi
        fi


        # BPE encode
        for INPUT_TYPE in $(seq 0 $(( INPUT_COUNT - 1 )))
        do
            LANG="input$INPUT_TYPE" 
            echo "BPE encoding $SPLIT/$LANG"
            python -m multiprocessing_bpe_encoder \
            --encoder-json encoder.json \
            --vocab-bpe vocab.bpe \
            --inputs "$TASK_DATA_FOLDER/processed/${SPLIT}_raw_$LANG.tsv" \
            --outputs "$TASK_DATA_FOLDER/processed/${SPLIT}_$LANG.tsv" \
            --workers 60 \
            --keep-empty;
        done
    done

    # Remove output directory
    rm -rf "$TASK-bin"
    
    DEVPREF="$TASK_DATA_FOLDER/processed/dev_LANG"
    TESTPREF="$TASK_DATA_FOLDER/processed/test_LANG"
    if [ "$TASK" = "MNLI" ]
    then
        DEVPREF="$TASK_DATA_FOLDER/processed/dev_matched_LANG,$TASK_DATA_FOLDER/processed/dev_mismatched_LANG"
        TESTPREF="$TASK_DATA_FOLDER/processed/test_matched_LANG,$TASK_DATA_FOLDER/processd/test_mismatched_LANG"
    fi

    # Run fairseq preprocessing:
    for INPUT_TYPE in $(seq 0 $(( INPUT_COUNT-1 )))
    do
        LANG="input$INPUT_TYPE.tsv"
        fairseq-preprocess \
            --only-source \
            --trainpref "$TASK_DATA_FOLDER/processed/train_$LANG" \
            --validpref "${DEVPREF//LANG/$LANG}" \
            --testpref "${TESTPREF//LANG/$LANG}" \
            --destdir "$TASK-bin/$LANG" \
            --workers 60 \
            --srcdict dict.txt;
    done
    if [[ "$TASK" != "STS-B" ]]
    then
        fairseq-preprocess \
            --only-source \
            --trainpref "$TASK_DATA_FOLDER/processed/train_label.tsv" \
            --validpref "${DEVPREF//LANG/label.tsv}" \
            --destdir "$TASK-bin/label" \
            --workers 60;
    else
        # For STS-B output range is converted to be between: [0.0, 1.0]
        mkdir -p "$TASK-bin/label"   # the label folder must exist before awk writes into it
        awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/train_label.tsv" > "$TASK-bin/label/train_label.tsv"
        awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/dev_label.tsv" > "$TASK-bin/label/valid_label.tsv"
    fi
done

Processing results: (screenshot of the generated output directories omitted)

三. Fine-tune with the preprocessed datasets

3.1 Download the RoBERTa model locally

Here I use the base model as an example. (screenshot of the downloaded model files omitted)
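A sketch of fetching the base model, assuming the roberta.base download URL from the fairseq README:

wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz
tar -xzvf roberta.base.tar.gz   # extracts roberta.base/model.pt among other files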

3.2 Fine-tuning task: RTE (binary sentence-pair classification)

  • Files involved:
      1. The pretrained model (.pt) file.
      2. The command: fairseq-hydra-train.
      3. The yaml config for the downstream task, here rte.yaml.
      4. The preprocessed dataset folder for the task, RTE-bin.
      5. The checkpoint save path (usually the same directory as the model file).
  • Shell code:
#!/bin/bash

ROBERTA_PATH=/home/phac123/LLM/RoBERTa/fine_tune_demo1_RTE/base/model.pt

CUDA_VISIBLE_DEVICES=1 fairseq-hydra-train --config-dir /home/phac123/LLM/RoBERTa/fine_tune_demo1_RTE/ --config-name rte \
task.data=/home/phac123/LLM/RoBERTa/fine_tune_demo1_RTE/RTE-bin checkpoint.restore_file=$ROBERTA_PATH
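The rte.yaml config does not have to be written from scratch: the fairseq repository ships fine-tuning configs for the GLUE tasks (assuming the usual repo layout, under examples/roberta/config/finetuning/), which can be copied into the directory passed to --config-dir:

# hypothetical paths; adjust to wherever the fairseq repo is checked out
cp fairseq/examples/roberta/config/finetuning/rte.yaml /home/phac123/LLM/RoBERTa/fine_tune_demo1_RTE/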
