Take the text-classification example implemented in huggingface transformers as an illustration.
When running it from bash, remove task_name (or any other argument that would auto-download a dataset).
Then add the following arguments as needed
(the train/validation/test sets must be .csv or .json files):
--train_file 训练集地址 \
--validation_file 验证集地址 \
--test_file 测试集地址 \
--do_train \
--do_eval \
--do_predict \
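Your local files need a header row. As an assumption for illustration (run_glue.py reads a "label" column for training and treats the remaining columns as the input sentences, so the text column name here is hypothetical), a minimal train.csv could be produced like this:

```python
import csv

# Hypothetical layout: one text column plus a "label" column.
# run_glue.py uses the non-"label" columns as the model input.
rows = [
    {"sentence": "the movie was great", "label": 1},
    {"sentence": "a complete waste of time", "label": 0},
]

with open("train.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["sentence", "label"])
    writer.writeheader()   # header row: sentence,label
    writer.writerows(rows)
```

The validation and test files should follow the same layout (the test file may omit labels, since it is only used for prediction).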
Here is my bash command:
python3 run_glue.py \
--model_name_or_path bert-base-uncased \
--train_file ./data/train.csv \
--validation_file ./data/dev.csv \
--test_file ./data/test.csv \
--do_train \
--do_eval \
--do_predict \
--max_seq_length 128 \
--per_device_train_batch_size 8 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir /tmp/run_glue/
With these arguments, the script will load your own data at data-loading time.
The code snippet that loads local data (from the transformers source):
# Loading a dataset from your local files.
# CSV/JSON training and evaluation files are needed.
data_files = {"train": data_args.train_file, "validation": data_args.validation_file}
# Get the test dataset: you can provide your own CSV/JSON test file (see below)
# when you use `do_predict` without specifying a GLUE benchmark task.
if training_args.do_predict:
    if data_args.test_file is not None:
        train_extension = data_args.train_file.split(".")[-1]
        test_extension = data_args.test_file.split(".")[-1]
        assert (
            test_extension == train_extension
        ), "`test_file` should have the same extension (csv or json) as `train_file`."
        data_files["test"] = data_args.test_file
    else:
        raise ValueError("Need either a GLUE task or a test file for `do_predict`.")

for key in data_files.keys():
    logger.info(f"load a local file for {key}: {data_files[key]}")

if data_args.train_file.endswith(".csv"):
    # Loading a dataset from local csv files
    raw_datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir)
else:
    # Loading a dataset from local json files
    raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir)
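Stripped of the Trainer context, the file-validation logic in the snippet can be exercised on its own. The helper below (the function name is my own, not part of transformers) is a sketch that reproduces it: it builds the data_files dict and enforces that the test file has the same extension as the train file.

```python
def build_data_files(train_file, validation_file, test_file=None, do_predict=False):
    """Sketch of run_glue.py's local-file validation: collect the file paths
    into a dict and check that train/test extensions (csv or json) match."""
    data_files = {"train": train_file, "validation": validation_file}
    if do_predict:
        if test_file is not None:
            train_extension = train_file.split(".")[-1]
            test_extension = test_file.split(".")[-1]
            assert test_extension == train_extension, (
                "`test_file` should have the same extension (csv or json) as `train_file`."
            )
            data_files["test"] = test_file
        else:
            raise ValueError("Need either a GLUE task or a test file for `do_predict`.")
    return data_files
```

For example, mixing a .csv train file with a .json test file fails the extension check, which is why the bash command above keeps all three files as .csv.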