Implementing a classification task with fairseq
https://fairseq.readthedocs.io/en/latest/tutorial_classifying_names.html
1. Preprocess the data to create the dictionaries
(1) Download the dataset: tutorial_names.tar.gz
https://dl.fbaipublicfiles.com/fairseq/data/tutorial_names.tar.gz
(2) Preprocess the data with fairseq
fairseq-preprocess --trainpref names/train --validpref names/valid --testpref names/test --source-lang input --target-lang label --destdir names-bin --dataset-impl raw
This creates the names-bin directory, which contains the dictionaries for the inputs and the labels.
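To make the role of these dictionaries concrete, here is a minimal sketch of what building one amounts to: count the whitespace-separated symbols and give each an integer index. This is not fairseq's own Dictionary class. The helper names (build_vocab, encode) and the sample lines are hypothetical. The reserved special symbols, the frequency ordering, and the one-name-per-line format with space-separated characters are assumptions made for illustration.

```python
from collections import Counter

# Special symbols assumed to be reserved at the start of the vocabulary;
# listed here only so that the indices in this sketch look realistic.
SPECIALS = ["<s>", "<pad>", "</s>", "<unk>"]


def build_vocab(lines):
    """Count whitespace-separated symbols and assign each one an index."""
    counts = Counter(tok for line in lines for tok in line.split())
    # Most frequent symbols first; ties are broken alphabetically so the
    # result is deterministic.
    ordered = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    index = {sym: i for i, sym in enumerate(SPECIALS)}
    for sym, _ in ordered:
        index[sym] = len(index)
    return counts, index


def encode(line, index):
    """Map a line to indices, using <unk> for unseen symbols and appending </s>."""
    unk = index["<unk>"]
    return [index.get(tok, unk) for tok in line.split()] + [index["</s>"]]


# Hypothetical sample lines: one name per line, characters separated by spaces.
sample_inputs = ["A n n a", "A d a m"]
counts, index = build_vocab(sample_inputs)
encoded = encode("A d a", index)
```

The same idea applies to the label side, where each line would hold a single label symbol instead of a sequence of characters.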
2. Register the model
Create a new Python file, rnn_classifier.py, under fairseq/models:
import torch
import torch.nn as nn
from fairseq.models import register_model_architecture
from fairseq.models import BaseFairseqModel, register_model
@register_model('rnn_classifier')
class FairseqRNNClassifier(BaseFairseqModel):

    @staticmethod
    def add_args(parser):
        # Models can override this method to add new command-line arguments.
        # Here we'll add a new command-line argument to configure the
        # dimensionality of the hidden state.
        parser.add_argument(
            '--hidden-dim', type=int, metavar='N',
            help='dimensionality of the hidden state',
        )

    @classmethod
    def build_model(cls, args, task):
        # Initialize our RNN module
        rnn = RNN(
            # We'll define the Task in the next section, but for now just