【FEW-NERD: A Few-shot Named Entity Recognition Dataset 官方代码解读(一)代码框架】

一、 背景

  • 命名实体识别是知识图谱构建过程中的重要一环,但由于命名实体源数据在获取其标签过程中需要花费较大代价,因此引入小样本命名实体识别,目的是提升标签缺失情况下对命名实体进行分类时的准确性。
  • 传统数据集适用于有标签模式下进行命名实体识别,并不是专门针对小样本研究而设计的,因此2021年清华大学和阿里团队发布Few-NERD数据集。

Few-NERD数据集介绍

此数据集专门针对小样本而设计,作者将其发表在 https://ningding97.github.io/fewnerd/

如上图所示,一共包含8个粗粒度实体类型和66个细粒度实体类型。数据集采集源域维基百科,语料丰富。
三个数据集可适用于不同模式下的命名实体识别。

  1. supervised:数据集按照7:2:1的比例被随即划分为训练集、验证集和测试集,三个集合都包含66个细粒度实体类型(不适用于小样本研究);
  2. INTRA:按照粗粒度实体进行分类,即训练集:People, MISC, Art, Product,验证集:Event, Building,测试集:ORG, LOC;
  3. INTER:按照细粒度进行划分,在每个粗粒度中,均随机挑选60%的细粒度实体类作为训练集,同理,每个粗粒度类中随机挑选20%、20%作为验证集和测试集;

二、 官方代码解读

作者将代码开源于https://github.com/thunlp/Few-NERD
首先本文章将会对该官方代码的整体框架进行解读,后续将发布对各模块进行再解读的文章。一则是为了方便自己掌握,二则有解读错误的地方欢迎大家进行指出。

(一)代码总模块

代码主要分成data、model、util、和主函数这四个模块

  1. data:获取数据以及存放所获得的数据集
  2. model:包含nnshot和proto两个模型
  3. util:所用到的一些工具模块
  4. 主函数:train_demo 为小样本模式下运行的主函数、run_supervised为有监督模式(非小样本)运行的主函数。
    解读的是小样本命名实体识别的相关代码,因此此文章解读的是train_demo.py,即小样本实体识别的整体框架

(二)train_demo.py

train_demo.py是运行的主函数,因此在此能展现其他各个模块之间的联系,此py文件中用到的其他模块以及对应的函数由import中可以看出,如下所示:
在这里插入图片描述
下面按顺序解析各个模块的内容。

1. 使用argparse批量定义和设置默认参数

'''
    ‘’内为参数名,可模糊匹配,
    help为参数解释帮助信息,
    default为默认设置,
    type为参数类型
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='inter',
            help='training mode, must be in [inter, intra]')
    parser.add_argument('--trainN', default=2, type=int,
            help='N in train')
    parser.add_argument('--N', default=2, type=int,
            help='N way')
    parser.add_argument('--K', default=2, type=int,
            help='K shot')
    parser.add_argument('--Q', default=3, type=int,
            help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int,
            help='batch size') #batch size是一次训练所抓取的样本数量
    parser.add_argument('--train_iter', default=600, type=int,
            help='num of iters in training')
    parser.add_argument('--val_iter', default=100, type=int,
            help='num of iters in validation')
    parser.add_argument('--test_iter', default=500, type=int,
            help='num of iters in testing')
    parser.add_argument('--val_step', default=20, type=int,
           help='val after training how many iters')
    parser.add_argument('--model', default='proto',
            help='model name, must be proto, nnshot, or structshot')
    parser.add_argument('--max_length', default=100, type=int,
           help='max length')
    parser.add_argument('--lr', default=1e-4, type=float,
           help='learning rate')
    parser.add_argument('--grad_iter', default=1, type=int,
           help='accumulate gradient every x iterations')
    parser.add_argument('--load_ckpt', default=None,
           help='load ckpt')
    parser.add_argument('--save_ckpt', default=None,
           help='save ckpt')
    parser.add_argument('--fp16', action='store_true',
           help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true',
           help='only test') #表示没有训练过程,只有测试
    parser.add_argument('--ckpt_name', type=str, default='',
           help='checkpoint name.')
    parser.add_argument('--seed', type=int, default=0,
           help='random seed')
    parser.add_argument('--ignore_index', type=int, default=-1,
           help='label index to ignore when calculating loss and metrics')
    parser.add_argument('--use_sampled_data', action='store_true',
           help='use released sampled data, the data should be stored at "data/episode-data/" ')
    #使用已发布的已采集数据

    # only for bert / roberta
    parser.add_argument('--pretrain_ckpt', default=None,
           help='bert / roberta pre-trained checkpoint')

    # only for prototypical networks
    parser.add_argument('--dot', action='store_true', 
           help='use dot instead of L2 distance for proto')

    # only for structshot
    parser.add_argument('--tau', default=0.05, type=float,
           help='StructShot parameter to re-normalizes the transition probabilities')

    # experiment
    parser.add_argument('--use_sgd_for_bert', action='store_true',
           help='use SGD instead of AdamW for BERT.')

action的默认设置是不读取参数,action设置为store_true时如果直接运行,该参数值为false,有关action的介绍,可参考这篇博文python中argparse模块,action=‘store_true‘

2. 设置实验,初始化模型参数

#设置实验,初始化模型参数
    opt = parser.parse_args()
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    max_length = opt.max_length

    #根据上述的参数默认设置,所作的实验为2way-2shot
    print("{}-way-{}-shot Few-Shot NER".format(N, K))
    print("model: {}".format(model_name)) #model默认为proto
    print("max_length: {}".format(max_length)) #max_size为100
    print('mode: {}'.format(opt.mode))  #默认使用inter数据集

    set_seed(opt.seed)
    print('loading model and tokenizer...')
    pretrain_ckpt = opt.pretrain_ckpt or 'bert-base-uncased'
    word_encoder = BERTWordEncoder(
            pretrain_ckpt)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

按照1中的默认参数设置,源代码默认实验设置为使用inter数据,用proto模型进行实验,遵循原文方法Nway-K~2Kshot进行使用自采样(非加载原作者已采样好的数据——便于当作baseline)。

3. 加载数据

由于前面use_sampled_data在设置参数的时候action设置为store_true,因此参数值默认为False,下面函数执行的时if代码块而非else。

print('loading data...')
    if not opt.use_sampled_data:
        #存储train、test、dev的地址
        opt.train = f'data/{opt.mode}/train.txt'
        opt.test = f'data/{opt.mode}/test.txt'
        opt.dev = f'data/{opt.mode}/dev.txt'
        if not (os.path.exists(opt.train) and os.path.exists(opt.dev) and os.path.exists(opt.test)):
            os.system(f'bash data/download.sh {opt.mode}')
    else:
        opt.train = f'data/episode-data/{opt.mode}/train_{opt.N}_{opt.K}.jsonl'
        opt.test = f'data/episode-data/{opt.mode}/test_{opt.N}_{opt.K}.jsonl'
        opt.dev = f'data/episode-data/{opt.mode}/dev_{opt.N}_{opt.K}.jsonl'
        if not (os.path.exists(opt.train) and os.path.exists(opt.dev) and os.path.exists(opt.test)):
            os.system(f'bash data/download.sh episode-data')
            os.system('unzip -d data/ data/episode-data.zip')
    
    if opt.mode == "supervised":
        print("Warning: you are running few-shot learning methods on `supervised` dataset, if it is not expected, please change to `--mode inter` or `--mode intra`.")

采集好数据集的地址后, 使用util/data_loader.py中国的get_loader()函数进行数据载入

    #使用util/data_loader.py中国的get_loader()函数进行数据加载
    train_data_loader = get_loader(opt.train, tokenizer,
            N=trainN, K=K, Q=Q, batch_size=batch_size, max_length=max_length, ignore_index=opt.ignore_index, use_sampled_data=opt.use_sampled_data)
    val_data_loader = get_loader(opt.dev, tokenizer,
            N=N, K=K, Q=Q, batch_size=batch_size, max_length=max_length, ignore_index=opt.ignore_index, use_sampled_data=opt.use_sampled_data)
    test_data_loader = get_loader(opt.test, tokenizer,
            N=N, K=K, Q=Q, batch_size=batch_size, max_length=max_length, ignore_index=opt.ignore_index, use_sampled_data=opt.use_sampled_data)

4.载入模型

载入数据后,将其喂入模型中,本论文一共于三个模型:proto、nnshot和structshot,默认使用proto;
并且使用了util/framework中的FewShotNERFramework()函数

#载入模型
    if model_name == 'proto':
        print('use proto')
        model = Proto(word_encoder, dot=opt.dot, ignore_index=opt.ignore_index)
        framework = FewShotNERFramework(train_data_loader, val_data_loader, test_data_loader, use_sampled_data=opt.use_sampled_data)
    elif model_name == 'nnshot':
        print('use nnshot')
        model = NNShot(word_encoder, dot=opt.dot, ignore_index=opt.ignore_index)
        framework = FewShotNERFramework(train_data_loader, val_data_loader, test_data_loader, use_sampled_data=opt.use_sampled_data)
    elif model_name == 'structshot':
        print('use structshot')
        model = NNShot(word_encoder, dot=opt.dot, ignore_index=opt.ignore_index)
        framework = FewShotNERFramework(train_data_loader, val_data_loader, test_data_loader, N=opt.N, tau=opt.tau, train_fname=opt.train, viterbi=True, use_sampled_data=opt.use_sampled_data)
    else:
        raise NotImplementedError

   
    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt
    print('model-save-path:', ckpt)

    if torch.cuda.is_available():
        model.cuda()

5.模型训练

模型训练中提供了两种模式,一个是训练+测试,另一种是only_test,1中也将only_test中的action设置为store_true。因此直接运行默认其值为false。

#模型训练

    if not opt.only_test:
        if opt.lr == -1:
            opt.lr = 2e-5

        framework.train(model, prefix,
                load_ckpt=opt.load_ckpt, save_ckpt=ckpt,
                val_step=opt.val_step, fp16=opt.fp16,
                train_iter=opt.train_iter, warmup_step=int(opt.train_iter * 0.1), val_iter=opt.val_iter, learning_rate=opt.lr, use_sgd_for_bert=opt.use_sgd_for_bert)
    else:
        ckpt = opt.load_ckp
        if ckpt is None :
            print("Warning: --load_ckpt is not specified. Will load Hugginface pre-trained checkpoint.")
            ckpt = 'none'

6. 模型测试

进行模型测试并进行精确度等的计算。

# test
    precision, recall, f1, fp, fn, within, outer = framework.eval(model, opt.test_iter, ckpt=ckpt)
    print("RESULT: precision: %.4f, recall: %.4f, f1:%.4f" % (precision, recall, f1))
    print('ERROR ANALYSIS: fp: %.4f, fn: %.4f, within:%.4f, outer: %.4f'%(fp, fn, within, outer))

总结

在小样本实体识别研究中这篇论文提供了proto、nnshot、structshot三种经典方法的baseline,学习使用这三种方法对今后的小样本NER具有重要意义。解析有不对的地方欢迎指出,共同学习,共同进步。
今后的解析将按照主函数所涉及的各个模块的先后顺序进行,后续先介绍使用原形网络进行实体抽取。

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值