论文地址:DAGA: Data Augmentation with a Generation Approach for Low-resource Tagging Tasks
github地址:https://github.com/ntunlp/daga
message: EMNLP2020;南洋理工&阿里达摩院
简介:
在数据增强方面,传统的EDA(Easy Data Augmentation)由四个简单但功能强大的操作组成:同义词替换、随机插入、随机交换和随机删除。DAGA目的是为低资源标记任务生成高质量的合成数据而提出了一种新的增强方法。
DAGA即相应代码主要针对命名实体识别(NER)任务和序列标注任务(POS)做了相关工作,本博客主要针对NER的相关内容以及复现中存在的问题做简单记录。在NER方面,DAGA基于LSTM研究数据增强方法,首先将标注数据线性化,将带标签的句子转换成线性序列,通过语言模型学习标注数据中单词和标签的分布情况。具体的,将标签插入到对应单词的前面,作为这些词的修饰语。同时提供相应处理策略:
- 频繁出现O,从线性化序列中移除此类标记。
- 数字,用N代替。
于是形成新的句子,如下图所示:
然后对线性化后的数据训练语言模型(language model),用于生成合成标记数据。
复现
下载github代码,NER主要使用到的代码是lstm-lm、tools等文件
环境
RTX2060
CUDA:10.2
python:3.6
torch:1.6.0
torchtext:0.6.0
conda create -n daga python=3.6 # 新建daga环境
conda install cudatoolkit=10.2 pytorch=1.6 -c pytorch # conda安装torch1.6
pip install torchtext==0.6.0
数据准备
我使用的是标准的NER数据集格式,词-标签,注意:中间分隔符为\t
宁 B-LOC
波 I-LOC
市 I-LOC
标注数据线性化
cd tools
python preprocess.py --train_file [训练集路径] --test_file [测试集路径] --dev_file [验证集路径] --vocab_size [自定义一个词表长度,我设置了10000]
最终生成带有lin的文件在tools中。
训练语言模型
python train.py --train_file [上一步生成的train.lin.txt] --valid_file [上一步生成的test.lin.txt] --model_file [模型保存的路径/model.pt] --emb_dim 300 --rnn_size 512 --gpuid 0
数据生成
python generate.py --model_file [上一步中模型路径/model.pt] --out_file [输出文件路径/out.txt] --num_sentences 10000 --temperature 1.0 --seed 3435 --max_sent_length 32 --gpuid 0
数据还原
python line2cols.py --inp_file [数据生成得到的结果文件路径/out.txt] --out_file [数据增强结果路径/data_augmentation.txt]
遇到的错误
问题1: RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED
RTX3080无法兼容低版本pytorch,建议更换服务器
问题2: AttributeError: ‘int’ object has no attribute ‘fields’
忘记了,好像也是版本问题。。。
问题3 :OSError: _torchtext.so: undefined symbol:
pytorch版本与torchtext版本不对应
网址:torchtext版本对应
问题4 :RuntimeError: CUDA error: no kernel image is available for execution on the device
cuda版本过高,安装低版本cudatoolkit
conda install cudatoolkit=10.2 pytorch=1.6 -c pytorch
问题5: AttributeError: module ‘torchtext.data’ has no attribute ‘iterator’
需要安装torchtext0.9之前的版本,torchtext0.9改成了torchtext.legacy.data.Iterator,torchtext0.12取消了对象的引用,所以最好使用低版本的torchtext
pip install torchtext==0.6.0