fastNLP框架实现NER

1.argparse四步走

// A highlighted block
# Step 1 of 4 for argparse: import the module, build a parser, declare the
# expected command-line arguments, then parse argv into a namespace.
import argparse
parser = argparse.ArgumentParser()
# Name of the dataset to train on (e.g. 'WB'); selects the per-dataset
# hyper-parameters configured later in the script.
parser.add_argument('--dataset', type=str)
# args holds the parsed command-line values (args.dataset, etc.).
args = parser.parse_args()

2.设置对应数据集参数

// A highlighted block
# Per-dataset hyper-parameter presets; shown here is the branch for the
# Weibo ('WB') corpus.  The full script presumably has one branch per
# supported dataset.
# NOTE(review): `dataset` is assumed to come from `args.dataset` — confirm
# against the full source.
if dataset == 'WB':
    n_heads = 2  # number of self-attention heads
    head_dims = 128  # dimensionality of each attention head
    num_layers = 2  # number of transformer encoder layers
    lr = 0.0007  # learning rate
    attn_type = 'adatrans'  # attention variant (adaptive transformer, per TENER)
    n_epochs = 50  # training epochs for this dataset

3.cache缓存数据

// A highlighted block
from fastNLP import cache_results
# Build a cache-file path that encodes every setting affecting preprocessing,
# so changing any of them produces a new cache instead of a stale hit.
name = 'caches/{}_{}_{}_{}_{}_{}_{}.pkl'.format(dataset, model_type, encoding_type, normalize_embed, args.pos_th,
                                                args.dep_th, args.chunk_th)
# Decorates the data-loading function so its result is pickled to `name` on
# the first run and re-loaded on later runs (_refresh=False keeps the cache).
# NOTE(review): the decorated function itself is outside this excerpt — a
# bare decorator like this is incomplete on its own.
@cache_results(name, _refresh=False)

4.加载数据

// A highlighted block
from fastNLP.embeddings import StaticEmbedding, BertEmbedding, StackEmbedding
from modules.pipe import CNNERPipe

def load_data():
    """Load the NER corpus and build stacked character embeddings.

    Reads train/dev/test splits from ``data/<dataset>/``, processes them
    with :class:`CNNERPipe` (adding bigram/trigram features), then stacks a
    static character embedding with a frozen BERT embedding.

    NOTE(review): this excerpt has no ``return`` statement — the full
    script presumably returns ``data_bundle`` and the embeddings; confirm
    against the original source.  ``dataset``, ``encoding_type``,
    ``normalize_embed``, ``args`` and ``tencent_embed`` are all defined
    outside this excerpt.
    """
    # One file per split under data/<dataset>/.
    paths = {
        'train':'data/{}/train.txt'.format(dataset),
        "dev":'data/{}/dev.txt'.format(dataset),
        "test":'data/{}/test.txt'.format(dataset)
    }
    # Character-level pipeline with bigram/trigram features; encoding_type
    # selects the tagging scheme used for the target labels.
    data_bundle = CNNERPipe(bigrams=True, trigrams=True, encoding_type=encoding_type).process_from_file(paths)
    # Static unigram character embedding (pretrained Gigaword vectors).
    embed = StaticEmbedding(data_bundle.get_vocab('chars'),
                            model_dir_or_name='data/gigaword_chn.all.a2b.uni.ite50.vec',
                            min_freq=1, only_norm_found_vector=normalize_embed, word_dropout=0.01, dropout=0.3)

    # Contextual BERT embedding over the same char vocab; last layer only,
    # kept frozen (requires_grad=False).
    bert_embed = BertEmbedding(vocab=data_bundle.get_vocab('chars'), model_dir_or_name=args.bert_model, layers='-1',
                               pool_method=args.pool_method, word_dropout=0, dropout=0.5, include_cls_sep=False,
                               pooled_cls=True, requires_grad=False, auto_truncate=False)

    # Concatenate static, Tencent and BERT embeddings into a single vector
    # per character.  Rebinds `embed` to the stacked embedding.
    embed = StackEmbedding([embed, tencent_embed, bert_embed], dropout=0, word_dropout=0.02)

5.导入模型

// A highlighted block
# Build the TENER model (Transformer Encoder for NER).  Most keyword values
# (d_model, bi_embed, feature2count, ...) are defined earlier in the full
# script; this excerpt only shows the constructor call.
# NOTE(review): this appears to be an extended TENER fork with external
# knowledge features (kv_attn_type / memory / fusion) and optional ZEN
# support — confirm against the project's model definition.
model = TENER(tag_vocab=data_bundle.get_vocab('target'), embed=embed, num_layers=num_layers,
              d_model=d_model, n_head=n_heads,
              feedforward_dim=dim_feedforward, dropout=args.trans_dropout,
              after_norm=after_norm, attn_type=attn_type,
              bi_embed=bi_embed,
              bi_embed2=bi_embed2,
              tri_embed = tri_embed,
              tri_embed2 = tri_embed2,
              tri_embed3 = tri_embed3,
              fc_dropout=fc_dropout,
              pos_embed=pos_embed,
              # Vanilla transformer attention is scaled; 'adatrans' is not.
              scale=attn_type=='transformer',
              use_knowledge=True,
              feature2count=feature2count,
              vocab_size=vocab_size,
              feature_vocab_size=feature_vocab_size,
              kv_attn_type=args.kv_attn_type,
              memory_dropout=args.memory_dropout,
              fusion_dropout=args.fusion_dropout,
              fusion_type=args.fusion_type,
              highway_layer=args.highway_layer,
              key_embed_dropout=args.key_embed_dropout,
              knowledge_type=args.knowledge_type,
              # ZEN is enabled only when a ZEN model path was supplied.
              use_zen=args.zen_model!=""
              )

6.设置优化器

// A highlighted block
from torch import optim

# Select the optimiser from the --optim_type CLI flag: SGD with momentum
# when 'sgd' is requested, Adam otherwise.
# NOTE(review): `args`, `model` and `lr` are defined earlier in the full script.
optimizer = (
    optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    if args.optim_type == 'sgd'
    else optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.99))
)

7.Callbacks的使用

// A highlighted block
from fastNLP import Trainer, GradientClipCallback, WarmupCallback
# Assemble the training callbacks: gradient clipping, periodic evaluation
# on the test set, and (optionally) learning-rate warmup.
callbacks = []
# Clip each gradient element to [-5, 5] to stabilise training.
clip_callback = GradientClipCallback(clip_type='value', clip_value=5)
# NOTE(review): `EvaluateCallback` is not imported in this excerpt and its
# keyword arguments (use_knowledge, *_th, feature maps) suggest a
# project-local subclass — confirm against the full source.
evaluate_callback = EvaluateCallback(data=data_bundle.get_dataset('test'),
                                     use_knowledge=True,
                                     knowledge_type=args.knowledge_type,
                                     pos_th=args.pos_th,
                                     dep_th=args.dep_th,
                                     chunk_th=args.chunk_th,
                                     test_feature_data=test_feature_data,
                                     feature2count=feature2count,
                                     feature2id=feature2id,
                                     id2feature=id2feature,
                                    #  use_zen=args.zen_model!="",
                                    #  zen_model=zen_model,
                                    #  zen_dataset=zen_test_dataset
                                     )
# Linear LR warmup is added only when a positive warmup_steps was configured.
if warmup_steps>0:
    warmup_callback = WarmupCallback(warmup_steps, schedule='linear')
    callbacks.append(warmup_callback)
callbacks.extend([clip_callback, evaluate_callback])

8.Trainer

// A highlighted block
from fastNLP import SpanFPreRecMetric, BucketSampler
# Build the Trainer: BucketSampler groups similar-length sentences per batch,
# SpanFPreRecMetric reports span-level precision/recall/F1 on the dev set.
# NOTE(review): many of these keyword arguments (use_knowledge, *_th,
# feature maps, logger_func, zen_*) are not part of stock fastNLP's Trainer —
# this is presumably a project-local fork; confirm against the full source.
trainer = Trainer(data_bundle.get_dataset('train'), model, optimizer, batch_size=batch_size, sampler=BucketSampler(),
                  num_workers=0, n_epochs=100, dev_data=data_bundle.get_dataset('dev'),
                  metrics=SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'), encoding_type=encoding_type),
                  dev_batch_size=batch_size, callbacks=callbacks, device=device, test_use_tqdm=False,
                  use_tqdm=True, print_every=300, save_path=save_path,
                  use_knowledge=True,
                  knowledge_type=args.knowledge_type,
                  pos_th=args.pos_th,
                  dep_th=args.dep_th,
                  chunk_th=args.chunk_th,
                  train_feature_data=train_feature_data,
                  # Dev-set features are passed as "test" features here —
                  # presumably because evaluation during training runs on dev.
                  test_feature_data=dev_feature_data,
                  feature2count=feature2count,
                  feature2id=feature2id,
                  id2feature=id2feature,
                  logger_func=write_log,
                  use_zen=args.zen_model!="",
                  zen_model=zen_model,
                  zen_train_dataset=zen_train_dataset,
                  zen_dev_dataset=zen_dev_dataset
                  )

# Run training; the best model is not re-loaded at the end (the
# EvaluateCallback above already tracks test performance per epoch).
trainer.train(load_best_model=False)
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值