Purpose
This article shows how to use fastNLP to reproduce BertGCN: Transductive Text Classification by Combining GCN and BERT, a paper published this year at a top conference.
fastNLP setup
This article uses fastNLP version 0.6.0, which can be installed with the following commands:
git clone -b dev https://github.com/fastnlp/fastNLP.git
cd fastNLP
python setup.py build
python setup.py install
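After installing, it can be worth confirming which version ended up in the environment. The snippet below is a minimal sketch using only the standard library; the helper name installed_version is hypothetical, and it assumes the package registers its metadata under the name fastNLP (name matching is normally case-insensitive).

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Assumption: the distribution is registered under the name "fastNLP".
version = installed_version("fastNLP")
if version is None:
    print("fastNLP is not installed in this environment")
elif not version.startswith("0.6.0"):
    print(f"fastNLP {version} found; this article assumes 0.6.0")
else:
    print(f"fastNLP {version} is installed")
```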
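Data preprocessing
The paper's architecture combines BERT and a GCN, so the dataset has to be processed in two steps. The first step turns the dataset into a graph and produces the graph's adjacency matrix; the second converts the dataset into the sequence format that BERT expects as input. fastNLP already wraps loader functions for the five datasets used in the paper, together with the PMIBuildGraph function, so the data-processing code stays short. The PrepareData class below does all of it.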
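As background for the first step, the sketch below shows roughly what a PMI-based text graph looks like. It is not fastNLP's PMIBuildGraph, only a small library-free illustration under stated assumptions: word-word edges carry positive PMI computed over sliding windows, document-word edges carry TF-IDF weights, every node has a self-loop, and document nodes are indexed before word nodes. The function names pmi_edges and build_adjacency, the window size, and the node ordering are all hypothetical choices made for this example.

```python
import math
from collections import Counter
from itertools import combinations


def pmi_edges(docs, window_size=3):
    """Return {(word_a, word_b): pmi} for word pairs whose PMI is positive."""
    windows = []
    for doc in docs:
        tokens = doc.split()
        if len(tokens) <= window_size:
            windows.append(tokens)
        else:
            for start in range(len(tokens) - window_size + 1):
                windows.append(tokens[start:start + window_size])

    word_count = Counter()  # number of windows containing each word
    pair_count = Counter()  # number of windows containing both words of a pair
    for window in windows:
        unique = sorted(set(window))
        word_count.update(unique)
        pair_count.update(combinations(unique, 2))

    total = len(windows)
    edges = {}
    for (w1, w2), n_pair in pair_count.items():
        # PMI = log( p(w1, w2) / (p(w1) * p(w2)) ), probabilities over windows
        pmi = math.log(n_pair * total / (word_count[w1] * word_count[w2]))
        if pmi > 0:  # keep only positively associated pairs
            edges[(w1, w2)] = pmi
    return edges


def build_adjacency(docs, window_size=3):
    """Dense symmetric adjacency over [documents..., words...] (assumed order)."""
    vocab = sorted({tok for doc in docs for tok in doc.split()})
    word_index = {w: len(docs) + i for i, w in enumerate(vocab)}
    size = len(docs) + len(vocab)
    adj = [[0.0] * size for _ in range(size)]
    for i in range(size):
        adj[i][i] = 1.0  # self-loops

    # Word-word edges weighted by positive PMI.
    for (w1, w2), pmi in pmi_edges(docs, window_size).items():
        i, j = word_index[w1], word_index[w2]
        adj[i][j] = adj[j][i] = pmi

    # Document-word edges weighted by TF-IDF (idf = log(N / df)).
    doc_freq = Counter(tok for doc in docs for tok in set(doc.split()))
    for d, doc in enumerate(docs):
        for word, tf in Counter(doc.split()).items():
            j = word_index[word]
            adj[d][j] = adj[j][d] = tf * math.log(len(docs) / doc_freq[word])
    return adj, word_index


if __name__ == "__main__":
    toy_docs = ["good movie fun", "bad movie dull", "good fun film"]
    adjacency, index = build_adjacency(toy_docs)
    print(len(adjacency), "nodes;", len(index), "of them are words")
```

A dense list-of-lists is only practical for toy inputs; real corpora need a sparse matrix. The fastNLP-based class used in this article follows.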
import torch as th
from transformers import AutoTokenizer
# Import path assumed for fastNLP 0.6.0 (dev branch); adjust if your install differs.
from fastNLP.io import (MRLoader, R8Loader, R52Loader, OhsumedLoader, NG20Loader,
                        MRPmiGraphPipe, R8PmiGraphPipe, R52PmiGraphPipe,
                        OhsumedPmiGraphPipe, NG20PmiGraphPipe)


class PrepareData:
    def __init__(self, args):
        self.arg = args
        self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model)
        # Pick the loader and graph pipe that match the requested dataset
        if self.arg.dataset == 'mr':
            data_bundle, adj, target_vocab = self._get_input(MRLoader, MRPmiGraphPipe, args.dev_ratio)
        elif self.arg.dataset == 'R8':
            data_bundle, adj, target_vocab = self._get_input(R8Loader, R8PmiGraphPipe, args.dev_ratio)
        elif self.arg.dataset == 'R52':
            data_bundle, adj, target_vocab = self._get_input(R52Loader, R52PmiGraphPipe, args.dev_ratio)
        elif self.arg.dataset == 'ohsumed':
            data_bundle, adj, target_vocab = self._get_input(OhsumedLoader, OhsumedPmiGraphPipe, args.dev_ratio)
        elif self.arg.dataset == '20ng':
            data_bundle, adj, target_vocab = self._get_input(NG20Loader, NG20PmiGraphPipe, args.dev_ratio)
        else:
            raise RuntimeError('Unknown dataset; choose one of ["mr", "R8", "R52", "ohsumed", "20ng"]')
        self.data_bundle = data_bundle
        self.target_vocab = target_vocab
        # Memory bank from the paper: one feature row per graph node
        feats = th.FloatTensor(th.randn((adj.shape[0], args.embed_size)))
        self.graph_info = {
            "adj": adj, "feats": feats}

    def _get_input(self, loader, buildGraph, dev_ratio=0.2):
        # Load the dataset
        load, bg = loader(), buildGraph()
        data_bundle = load.load(load.download(dev_ratio=dev_ratio))
        adj, index = bg.build_graph(data_bundle)
        # Add a doc label so each document can be located in the graph
        data_bundle.get_dataset