HuggingFace TAPAS Model细节记录

HuggingFace TAPAS Model细节记录:
目前google的TAPAS model在huggingface上都能找到各种预训练model和训练好的model,所以这里就做一下model的尝试,以及finetune过程。

首先看一下最为常用的训练好的model:
hugging face TableQA model
这里面有大量的model,举例一下常用的:

google/tapas-base-finetuned-wtq

(download 19.2k)model
那这个和基础的TAPAS原论文model有啥区别呢?

  • version
    首先是这个model有两个版本,区别在于position-embedding,一个用的是相对位置索引(在表格的每个cell开始时重置一下position index,默认是使用这个,效果会好一点),另一个是绝对位置索引。

  • intermediate pre-training (不感兴趣的可以忽略这个)
    这个是训练模型的一部分,在基础的TAPAS的pre-training和fintune之间增加的一步,这个很多huggingface上很多基于tapas的model都用了。下面用的都是wikipedia的table和相关text,wikipedia的table保留至少两行两列(header加上就是三行),递归地将表格逐行拆分为上半部分和下半部分,直到它们最多有 50 个单元格。这样我们就获得了 370 万张表。

    具体来说,又分为两步:

    • Synthetic Statements
      这部分是在wikipedia的table上合成句子和SQL,用于提高模型的数值运算处理和比较的能力。实际就是按照一定的模版生成与table相关的句子以及SQL,而且相关的句子以一定的概率进行扰乱。句子和SQL的语法规则如下:
      在这里插入图片描述
      <>尖括号包裹的是SQL语句,其余的是生成句子的短语,按照这种语法规则,生成的句子以及对应的SQL,并且在正常的句子的前面或者后面再加一层数字对比短语,这样的句子可能不会通顺,但是没关系,比如:
      example table:

      Rank Player Country Earnings Events Wins
      1 Greg Norman Australia 1,654,959 16 3
      2 Billy Mayfair United States 1,543,192 28 2
      3 Lee Janzen United States 1,378,966 28 3
      4 Corey Pavin United States 1,340,079 22 2
      5 Steve Elkington Australia 1,254,352 21 2 Australia 1,254,352 21 2

      合成句子:2 is less than wins when Player is Lee Janzen.
      SQL:SELECT wins FROM table WHERE player = “Lee Janzen”
      结果:通过SQL产生的结果为3,2 is less than 3这是正确的,positive
      负例句子:3 is less than wins when Player is Lee Janzen. 这样的就是negative
      像这样的生成了370万对,模型是将生成的句子和表格作为输入,输出是判断正错。

    • Counterfactual Statements
      这里是对wikipedia表附近的文本进行entity replace,比如说,原句是:Greg Norman has the highest earnings,我们替换成:Steve Elkington has the highest earnings.
      这样生成了410万句子,模型任务是判断句子是否被替换。
      可能有些理解不是很正确,感兴趣的可以参考:
      github描述
      论文3.1 和3.2章节

fintune

依次通过 SQA, WikiSQL and finally WTQ.

code test

环境
  • torch.version = ‘1.6.0+cu101’
  • torch-scatter (这个包一定要安装):pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html
    这两个对应上就好,哪个版本无所谓,官网给的是1.8.0+cu101
  • torch-sparse:pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html
tqa = pipeline(task="table-question-answering", model="google/tapas-base-finetuned-wtq")
table = pd.DataFrame({
   "fund":["CSI 300", "Bank of communications","China Shipping"],
                      "type":["income","hybrid","income"],
                      "annual increase":[0.15,0.18,0.16]})
table = table.astype(str)
table

在这里插入图片描述

query = ["Which funds its type is income?", 
         "Which fund has the highest annual return?",
         "Is there a fund of income type?",
         "What are the funds with an annual increase of more than 0.2?"
  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值