BertGCN的fastNLP实现

目的

本文主要介绍如何实现fastNLP 来复现今年发表在顶会的一篇论文BertGCN: Transductive Text Classification by Combining GCN and BERT。

FastNLP配置

本文采用的fastNLP版本号为0.6.0,可采用一下命令来安装

pip install -b dev https://github.com/fastnlp/fastNLP.git
python setup.py build
python setup.py install

数据预处理

论文采用的架构是bert和gcn,故数据集需要分两步来处理,第一步是将数据集处理成一张图,得到图的邻接矩阵,第二步将其处理成适应bert输入的序列形式。由于FastNLP封装了论文中采用的5个数据集的loader函数和PMIBuildGraph函数,故数据处理的代码变得很简单,具体如下:

class PrepareData:
    def __init__(self, args):
        self.arg = args
        self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model)
        if self.arg.dataset == 'mr':
            data_bundle, adj, target_vocab = self._get_input(MRLoader, MRPmiGraphPipe,  args.dev_ratio)
        elif self.arg.dataset == 'R8':
            data_bundle, adj, target_vocab = self._get_input(R8Loader, R8PmiGraphPipe,  args.dev_ratio)
        elif self.arg.dataset == 'R52':
            data_bundle, adj, target_vocab = self._get_input(R52Loader, R52PmiGraphPipe,  args.dev_ratio)
        elif self.arg.dataset == 'ohsumed':
            data_bundle, adj, target_vocab = self._get_input(OhsumedLoader, OhsumedPmiGraphPipe,  args.dev_ratio)
        elif self.arg.dataset == '20ng':
            data_bundle, adj, target_vocab = self._get_input(NG20Loader, NG20PmiGraphPipe,  args.dev_ratio)
        else:
            raise RuntimeError('输入数据集错误,请更改为["mr", "R8", "R52", "ohsumed", "20ng"]')

        self.data_bundle = data_bundle
        self.target_vocab = target_vocab
        ## 论文中的memory bank实现形式
        feats = th.FloatTensor(th.randn((adj.shape[0], args.embed_size)))
        self.graph_info = {
   "adj": adj, "feats": feats}

    def _get_input(self, loader:loader, buildGraph, dev_ratio=0.2):
        ##加载数据集
        load, bg = loader(), buildGraph()
        data_bundle = load.load(load.download(dev_ratio=dev_ratio))
        adj, index = bg.build_graph(data_bundle)
        ## 添加doc标签,以便于在图中定位文档的位置
        data_bundle.get_dataset
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值